|
import random |
|
import sys |
|
from random import randint |
|
|
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
from anndata import AnnData |
|
|
|
|
|
sys.path.append('/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/transforna/') |
|
from src import (Results_Handler, correct_labels, load, predict_transforna, |
|
predict_transforna_all_models,get_fused_seqs) |
|
|
|
|
|
def get_mc_sc(infer_df,sequences,sub_classes_used_for_training,sc_to_mc_mapper_dict,ood_flag = False):
    """Build aligned major-class / sub-class label tables for ``sequences``.

    The ``subclass_name`` entry of every requested sequence is split on ';'
    into one column per annotation (shorter rows are padded with nulls).
    Rows are then filtered:

    * ``ood_flag=False``: keep rows whose non-null sub classes are all in
      ``sub_classes_used_for_training``.
    * ``ood_flag=True``: keep rows whose non-null sub classes are all outside
      the training sub classes *and their neighbours* (adjacent ``_bin-N``
      bins, and for names containing ``'tR'`` every mapper key that contains
      both the text before the first '-' and the text after the first '__'),
      and do not contain ``'hypermapper'``.

    Args:
        infer_df: DataFrame indexed by sequence with a ``subclass_name`` column.
        sequences: index labels of ``infer_df`` to evaluate.
        sub_classes_used_for_training: iterable of sub class names.
        sc_to_mc_mapper_dict: mapping from sub class name to major class name.
        ood_flag: select the out-of-distribution filter described above.

    Returns:
        ``(mc_classes_df, sc_classes_df)`` sharing the same index.

    Raises:
        Exception: if no row survives the sub class filter.
    """

    def _elementwise(frame, func):
        # ``DataFrame.applymap`` is deprecated in recent pandas; mapping each
        # column with ``Series.map`` is equivalent and version independent.
        return frame.apply(lambda column: column.map(func))

    inferred_seqs = infer_df.loc[sequences]
    sc_classes_df = inferred_seqs['subclass_name'].str.split(';',expand=True)

    # Drop sequences that carry no annotation at all.
    sc_classes_df = sc_classes_df[~sc_classes_df.isnull().all(axis=1)]

    if ood_flag:
        known_sub_classes = set()
        for sub_class in sub_classes_used_for_training:
            known_sub_classes.add(sub_class)
            if 'bin' in sub_class:
                bin_prefix = sub_class.split('_bin-')[0]
                bin_num = int(sub_class.split('_bin-')[1])
                # Neighbouring bins count as "seen"; there is no bin below 0.
                if bin_num > 0:
                    known_sub_classes.add(f'{bin_prefix}_bin-{bin_num-1}')
                known_sub_classes.add(f'{bin_prefix}_bin-{bin_num+1}')
            if 'tR' in sub_class:
                first_part = sub_class.split('-')[0]
                second_part = sub_class.split('__')[1]
                known_sub_classes.update(
                    sc for sc in sc_to_mc_mapper_dict if first_part in sc and second_part in sc
                )

        def _is_allowed(value):
            # The null check must come first: padding cells are not strings,
            # so the substring test below would raise TypeError on them.
            if pd.isnull(value):
                return True
            return value not in known_sub_classes and 'hypermapper' not in value
    else:
        trained_sub_classes = set(sub_classes_used_for_training)

        def _is_allowed(value):
            return bool(pd.isnull(value)) or value in trained_sub_classes

    keep_rows = _elementwise(sc_classes_df, _is_allowed).all(axis=1).astype(bool)

    if keep_rows.sum() == 0:
        import logging
        log_ = logging.getLogger(__name__)
        log_.error('None of the sub classes used for training are in the sequences')
        raise Exception('None of the sub classes used for training are in the sequences')

    sc_classes_df = sc_classes_df[keep_rows]

    # NOTE(review): null padding cells are not keys of the mapper, so they map
    # to 'not_found' and the whole row is dropped below. This matches the
    # previous behaviour; confirm that dropping such rows is intended.
    mc_classes_df = _elementwise(
        sc_classes_df,
        lambda value: sc_to_mc_mapper_dict[value] if value in sc_to_mc_mapper_dict else 'not_found',
    )
    mc_classes_df = mc_classes_df[~mc_classes_df.eq('not_found').any(axis=1)]

    # Drop rows whose first major class is itself a ';'-joined multi label.
    mc_classes_df = mc_classes_df[~mc_classes_df[0].str.contains(';')]

    # Keep both tables on the same set of sequences.
    sc_classes_df = sc_classes_df.loc[mc_classes_df.index]
    mc_classes_df = mc_classes_df.loc[sc_classes_df.index]
    return mc_classes_df,sc_classes_df
|
|
|
def plot_confusion_false_novel(df,sc_df,mc_df,save_figs:bool=False):
    """Plot the major-class confusion matrix of falsely-novel sequences.

    Reads the module-level ``sc_to_mc_mapper_dict`` to translate predicted
    sub classes into major classes; predictions missing from the mapper fall
    back to substring rules (e.g. 'miR' -> 'miRNA') and finally to the sub
    class name itself.

    Args:
        df: predictions indexed by sequence with a ``predicted_sc_labels``
            column.
        sc_df: actual sub class table indexed by sequence (column 0 is used).
        mc_df: actual major class table indexed by sequence (column 0 is used).
        save_figs: when True the figure is written to
            ``transforna/bin/lc_figures``.

    Returns:
        None.
    """
    # Nothing to plot; the label assignments below cannot handle an empty frame.
    if df.empty:
        return

    # Checked in order; the first marker contained in the label wins.
    fallback_rules = (
        ('miR', 'miRNA'),
        ('tRNA', 'tRNA'),
        ('rRNA', 'rRNA'),
        ('snRNA', 'snRNA'),
        ('snoRNA', 'snoRNA'),
        ('SNOR', 'snoRNA'),
        ('RPL37A', 'protein_coding'),
        ('SNHG1', 'lncRNA'),
    )

    def _major_class(sc_label):
        if sc_label in sc_to_mc_mapper_dict:
            return sc_to_mc_mapper_dict[sc_label]
        for marker, major_class in fallback_rules:
            if marker in sc_label:
                return major_class
        return sc_label

    curr_sc_classes_df = sc_df.loc[[i for i in df.index if i in sc_df.index]]
    curr_mc_classes_df = mc_df.loc[[i for i in df.index if i in mc_df.index]]

    df = df.assign(predicted_mc_labels=df['predicted_sc_labels'].map(_major_class))
    df = df.assign(actual_mc_labels=curr_mc_classes_df[0].values.tolist())
    df = df.assign(actual_sc_labels=curr_sc_classes_df[0].values.tolist())
    df = df.assign(mc_accuracy=(df['actual_mc_labels'] == df['predicted_mc_labels']).astype(int))
    df = df.assign(sc_accuracy=(df['actual_sc_labels'] == df['predicted_sc_labels']).astype(int))

    # Row-normalised confusion matrix (rows: actual, columns: predicted).
    mc_confusion_matrix = df.groupby(['actual_mc_labels','predicted_mc_labels'])['mc_accuracy'].count().unstack()
    mc_confusion_matrix = mc_confusion_matrix.fillna(0)
    mc_confusion_matrix = mc_confusion_matrix.div(mc_confusion_matrix.sum(axis=1), axis=0)
    mc_confusion_matrix = mc_confusion_matrix.round(2)

    # Predicted classes that never occur as an actual class are folded into a
    # single 'other' column. Summing the columns with pandas is required here:
    # ``list += Series`` would extend the list instead of adding element-wise.
    unseen_predictions = [c for c in mc_confusion_matrix.columns if c not in mc_confusion_matrix.index.tolist()]
    mc_confusion_matrix['other'] = mc_confusion_matrix[unseen_predictions].sum(axis=1)
    # Matching all-zero 'other' row keeps the matrix square.
    mc_confusion_matrix.loc['other'] = [0]*mc_confusion_matrix.shape[1]
    mc_confusion_matrix = mc_confusion_matrix.drop(unseen_predictions, axis=1)

    fig = go.Figure(data=go.Heatmap(
        z=mc_confusion_matrix.values,
        x=mc_confusion_matrix.columns,
        y=mc_confusion_matrix.index,
        hoverongaps = False))

    cell_values = mc_confusion_matrix.values
    for row_pos, actual_label in enumerate(mc_confusion_matrix.index):
        for col_pos, predicted_label in enumerate(mc_confusion_matrix.columns):
            fig.add_annotation(text=str(cell_values[row_pos][col_pos]), x=predicted_label, y=actual_label,
                               showarrow=False, font_size=25, font_color='red')

    fig.update_layout(title_text='Confusion matrix based on mc classes for false novel sequences')
    fig.update_xaxes(title_text='Predicted mc class')
    fig.update_yaxes(title_text='Actual mc class')

    if save_figs:
        # NOTE(review): the file name carries no model identifier, so repeated
        # calls overwrite the same image.
        fig.write_image('transforna/bin/lc_figures/confusion_matrix_mc_classes_false_novel.png')
|
|
|
|
|
def compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers = False,fig_prefix:str = '',save_figs:bool=False):
    """Report per-model accuracy and plot major-class confusion matrices.

    For every model in ``prediction_pd['Model']`` this prints plain and
    balanced accuracy on sub class and major class level, plots pies and a
    confusion matrix for the sequences flagged as novel
    (``NLD > Novelty Threshold``), and plots the overall major-class
    confusion matrix.

    Reads the module-level ``sc_to_mc_mapper_dict``; when
    ``seperate_outliers`` is False an ``'Outlier'`` entry is added to it.

    Args:
        prediction_pd: predictions with at least the columns ``Model``,
            ``Sequence``, ``Net-Label``, ``NLD`` and ``Novelty Threshold``.
        sc_classes_df: actual sub class table indexed by sequence (column 0).
        mc_classes_df: actual major class table indexed by sequence (column 0).
        seperate_outliers: True drops sequences flagged as novel before
            scoring; False keeps them, relabels them ``'Outlier'`` and writes
            CSVs of the true/false major class predictions.
        fig_prefix: prefix for output file names; ``'LC-familiar'`` also
            selects a smaller annotation font.
        save_figs: when True figures are written to
            ``transforna/bin/lc_figures``.

    Returns:
        None.
    """
    font_size = 10 if fig_prefix == 'LC-familiar' else 25

    prediction_pd = prediction_pd.rename(columns={'Net-Label':'predicted_sc_labels'})
    num_rows = sc_classes_df.shape[0]

    # Checked in order; the first marker contained in the label wins.
    fallback_rules = (
        ('miR', 'miRNA'),
        ('tRNA', 'tRNA'),
        ('rRNA', 'rRNA'),
        ('snRNA', 'snRNA'),
        ('snoRNA', 'snoRNA'),
        ('SNOR', 'snoRNA'),
        ('RPL37A', 'protein_coding'),
        ('SNHG1', 'lncRNA'),
    )

    def _major_class(sc_label):
        if sc_label in sc_to_mc_mapper_dict:
            return sc_to_mc_mapper_dict[sc_label]
        for marker, major_class in fallback_rules:
            if marker in sc_label:
                return major_class
        return sc_label

    for model in prediction_pd['Model'].unique():

        model_prediction_pd = prediction_pd[prediction_pd['Model'] == model]
        model_prediction_pd = model_prediction_pd.set_index('Sequence')
        # Only score sequences with a known actual label. The explicit copy
        # makes the in-place relabelling further down unambiguous.
        model_prediction_pd = model_prediction_pd.loc[[i for i in model_prediction_pd.index if i in sc_classes_df.index]].copy()

        try:
            embedds_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/{model}/embedds'
            results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])
        except Exception:
            # Fall back to the Seq-Rev embeddings when the model has none of
            # its own. ``Exception`` (not a bare except) keeps Ctrl-C working.
            embedds_path = f'/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/TransfoRNA_FULL/sub_class/Seq-Rev/embedds'
            results:Results_Handler = Results_Handler(embedds_path=embedds_path,splits=['train'])

        train_seqs = set(results.splits_df_dict['train_df']['RNA Sequences']['0'].values.tolist())
        common_seqs = train_seqs.intersection(set(model_prediction_pd.index.tolist()))
        print(f'Number of common seqs between train_df and {model} is {len(common_seqs)}')

        is_novel = model_prediction_pd['NLD'] > model_prediction_pd['Novelty Threshold']
        num_outliers = sum(is_novel)
        false_novel_df = model_prediction_pd[is_novel]

        # The false-novel plots cannot be built from an empty selection.
        if not false_novel_df.empty:
            plot_confusion_false_novel(false_novel_df,sc_classes_df,mc_classes_df,save_figs)

            fig_outl = mc_classes_df.loc[false_novel_df.index][0].value_counts().plot.pie(autopct='%1.1f%%',figsize=(6, 6))
            fig_outl.set_title(f'False Novel per MC for {model}: {num_outliers}')
            if save_figs:
                fig_outl.get_figure().savefig(f'transforna/bin/lc_figures/false_novel_mc_{model}.png')
            # Clear the figure so the next pie does not draw on top of it.
            fig_outl.get_figure().clf()

            false_novel_sc_freq_df = sc_classes_df.loc[false_novel_df.index][0].value_counts().to_frame()
            # The counts column is named 0 on pandas < 2 and 'count' on
            # pandas >= 2, so look it up instead of hard-coding it.
            count_col = false_novel_sc_freq_df.columns[0]
            false_novel_sc_freq_df['MC'] = false_novel_sc_freq_df.index.map(lambda x: sc_to_mc_mapper_dict[x])

            fig_outl_sc = false_novel_sc_freq_df.groupby('MC')[count_col].sum().plot.pie(autopct='%1.1f%%',figsize=(6, 6))
            fig_outl_sc.set_title(f'False novel: No. Unique sub classes per MC {model}: {num_outliers}')
            if save_figs:
                fig_outl_sc.get_figure().savefig(f'transforna/bin/lc_figures/{fig_prefix}_false_novel_sc_{model}.png')
            fig_outl_sc.get_figure().clf()

        if seperate_outliers:
            model_prediction_pd = model_prediction_pd[model_prediction_pd['NLD'] <= model_prediction_pd['Novelty Threshold']]
        else:
            # Keep the flagged sequences but score them as their own class.
            model_prediction_pd.loc[is_novel,'predicted_sc_labels'] = 'Outlier'
            model_prediction_pd.loc[is_novel,'predicted_mc_labels'] = 'Outlier'
            sc_to_mc_mapper_dict['Outlier'] = 'Outlier'

        curr_sc_classes_df = sc_classes_df.loc[[i for i in model_prediction_pd.index if i in sc_classes_df.index]]
        curr_mc_classes_df = mc_classes_df.loc[[i for i in model_prediction_pd.index if i in mc_classes_df.index]]

        model_prediction_pd = model_prediction_pd.assign(predicted_mc_labels=model_prediction_pd['predicted_sc_labels'].map(_major_class))
        model_prediction_pd = model_prediction_pd.assign(actual_mc_labels=curr_mc_classes_df[0].values.tolist())
        model_prediction_pd = model_prediction_pd.assign(actual_sc_labels=curr_sc_classes_df[0].values.tolist())

        model_prediction_pd['predicted_sc_labels'] = correct_labels(model_prediction_pd['predicted_sc_labels'],model_prediction_pd['actual_sc_labels'],sc_to_mc_mapper_dict)

        model_prediction_pd = model_prediction_pd.assign(mc_accuracy=(model_prediction_pd['actual_mc_labels'] == model_prediction_pd['predicted_mc_labels']).astype(int))
        model_prediction_pd = model_prediction_pd.assign(sc_accuracy=(model_prediction_pd['actual_sc_labels'] == model_prediction_pd['predicted_sc_labels']).astype(int))

        if not seperate_outliers:
            cols_to_save = ['actual_mc_labels','predicted_mc_labels','predicted_sc_labels','actual_sc_labels']

            total_false_mc_predictions_df = model_prediction_pd[model_prediction_pd.actual_mc_labels != model_prediction_pd.predicted_mc_labels].loc[:,cols_to_save]
            total_false_mc_predictions_df['is_novel'] = model_prediction_pd.loc[total_false_mc_predictions_df.index]['NLD'] > model_prediction_pd.loc[total_false_mc_predictions_df.index]['Novelty Threshold']
            total_false_mc_predictions_df.to_csv(f'transforna/bin/lc_files/{fig_prefix}_total_false_mcs_w_out_{model}.csv')

            total_true_mc_predictions_df = model_prediction_pd[model_prediction_pd.actual_mc_labels == model_prediction_pd.predicted_mc_labels].loc[:,cols_to_save]
            total_true_mc_predictions_df['is_novel'] = model_prediction_pd.loc[total_true_mc_predictions_df.index]['NLD'] > model_prediction_pd.loc[total_true_mc_predictions_df.index]['Novelty Threshold']
            total_true_mc_predictions_df.to_csv(f'transforna/bin/lc_files/{fig_prefix}_total_true_mcs_w_out_{model}.csv')

        # Balanced accuracy = mean of the per-class accuracies.
        mc_balanced_accuracy = model_prediction_pd.groupby('actual_mc_labels')['mc_accuracy'].mean().mean()
        sc_balanced_accuracy = model_prediction_pd.groupby('actual_sc_labels')['sc_accuracy'].mean().mean()

        print('Model: ', model)
        print('num_outliers: ', num_outliers)
        print('mc_accuracy: ', model_prediction_pd['mc_accuracy'].mean())
        print('sc_accuracy: ', model_prediction_pd['sc_accuracy'].mean())
        print('mc_balanced_accuracy: ', mc_balanced_accuracy)
        print('sc_balanced_accuracy: ', sc_balanced_accuracy)

        # Row-normalised confusion matrix (rows: actual, columns: predicted).
        mc_confusion_matrix = model_prediction_pd.groupby(['actual_mc_labels','predicted_mc_labels'])['mc_accuracy'].count().unstack()
        mc_confusion_matrix = mc_confusion_matrix.fillna(0)
        mc_confusion_matrix = mc_confusion_matrix.div(mc_confusion_matrix.sum(axis=1), axis=0)
        mc_confusion_matrix = mc_confusion_matrix.round(4)

        # Predicted classes that never occur as an actual class are folded
        # into one 'other' column. Summing with pandas is required here:
        # ``list += Series`` would extend the list instead of adding
        # element-wise.
        unseen_predictions = [c for c in mc_confusion_matrix.columns if c not in mc_confusion_matrix.index.tolist()]
        mc_confusion_matrix['other'] = mc_confusion_matrix[unseen_predictions].sum(axis=1)
        # Matching all-zero 'other' row keeps the matrix square.
        mc_confusion_matrix.loc['other'] = [0]*mc_confusion_matrix.shape[1]
        mc_confusion_matrix = mc_confusion_matrix.drop(unseen_predictions, axis=1)

        fig = go.Figure(data=go.Heatmap(
            z=mc_confusion_matrix.values,
            x=mc_confusion_matrix.columns,
            y=mc_confusion_matrix.index,
            colorscale='Blues',
            hoverongaps = False))

        cell_values = mc_confusion_matrix.values
        for row_pos, actual_label in enumerate(mc_confusion_matrix.index):
            for col_pos, predicted_label in enumerate(mc_confusion_matrix.columns):
                fig.add_annotation(text=str(round(cell_values[row_pos][col_pos],2)), x=predicted_label, y=actual_label,
                                   showarrow=False, font_size=font_size, font_color='black')

        fig.update_layout(
            title='Confusion matrix for mc classes - ' + model + ' - ' + 'mc B. Acc: ' + str(round(mc_balanced_accuracy,2)) \
                + ' - ' + 'sc B. Acc: ' + str(round(sc_balanced_accuracy,2)) + '<br>' + \
                'percent false novel: ' + str(round(num_outliers/num_rows,2)),
            xaxis_nticks=36)

        fig.update_xaxes(title_text='Predicted mc class')
        fig.update_yaxes(title_text='Actual mc class')

        if save_figs:
            if seperate_outliers:
                fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_no_out_' + model + '.png')
                fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_no_out_' + model + '.svg')
            else:
                fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_outliers_' + model + '.png')
                fig.write_image(f'transforna/bin/lc_figures/{fig_prefix}_LC_confusion_matrix_mc_outliers_' + model + '.svg')
        print('\n')
|
|
|
|
|
if __name__ == '__main__':
    # NOTE: this deliberately stays a flat script. compute_accuracy and
    # plot_confusion_false_novel read ``sc_to_mc_mapper_dict`` as a
    # module-level global, so it must be assigned at module scope.

    mapping_dict_path = '/media/ftp_share/hbdx/data_for_upload/TransfoRNA//data/subclass_to_annotation.json'
    LC_path = '/media/ftp_share/hbdx/annotation/feature_annotation/ANNOTATION/HBDxBase_annotation/TransfoRNA/compare_binning_strategies/v05/2024-04-19__230126_LC_DI_HB_GEL_v23.01.00/sRNA_anno_aggregated_on_seq.csv'
    path_to_models = '/nfs/home/yat_ldap/VS_Projects/TransfoRNA-Framework/models/tcga/'

    trained_on = 'full'
    save_figs = True

    # Which evaluation split to run; must be a key of ``fig_prefix_per_split``.
    split = 'infer_hico'
    print(f'Running inference for {split}')

    # Every supported split and the prefix used for its output files.
    fig_prefix_per_split = {
        'infer_aa': '5\'A-affixes',
        'infer_other_affixes': 'other_affixes',
        'infer_relaxed_mirna': 'Relaxed-miRNA',
        'infer_hico': 'LC-familiar',
        'infer_ood': 'LC-novel',
        'infer_random': 'Random',
        'infer_fused': 'Fused',
        'infer_na': 'NA',
        'infer_loco': 'LOCO',
    }
    if split not in fig_prefix_per_split:
        raise Exception(f"Unknown split {split!r}; expected exactly one of: {', '.join(fig_prefix_per_split)}")
    fig_prefix = fig_prefix_per_split[split]

    # Exactly one of these flags is True.
    infer_aa = split == 'infer_aa'
    infer_relaxed_mirna = split == 'infer_relaxed_mirna'
    infer_hico = split == 'infer_hico'
    infer_ood = split == 'infer_ood'
    infer_other_affixes = split == 'infer_other_affixes'
    infer_random = split == 'infer_random'
    infer_fused = split == 'infer_fused'
    infer_na = split == 'infer_na'
    infer_loco = split == 'infer_loco'

    infer_df = load(LC_path)
    if isinstance(infer_df,AnnData):
        infer_df = infer_df.var
    infer_df.set_index('sequence',inplace=True)
    sc_to_mc_mapper_dict = load(mapping_dict_path)

    # Default selection (used as-is by infer_hico and infer_ood, and as the
    # input of infer_fused / infer_relaxed_mirna).
    hico_seqs = infer_df.index[infer_df['hico']].tolist()
    art_affix_seqs = infer_df[~infer_df['five_prime_adapter_filter']].index.tolist()

    if infer_aa:
        hico_seqs = art_affix_seqs
    elif infer_other_affixes:
        hico_seqs = infer_df[~infer_df['hbdx_spikein_affix_filter']].index.tolist()
    elif infer_na:
        hico_seqs = infer_df[infer_df.subclass_name == 'no_annotation'].index.tolist()
    elif infer_loco:
        # Combined into one mask: chaining two ``[]`` filters applies a mask
        # built on the full frame to an already filtered one.
        hico_seqs = infer_df[~infer_df['hico'] & (infer_df.subclass_name != 'no_annotation')].index.tolist()
    elif infer_relaxed_mirna:
        is_mirna = infer_df.subclass_name.str.contains('miR') | infer_df.subclass_name.str.contains('let')
        is_single_label = ~infer_df.subclass_name.str.contains(';')
        mirnas_seqs = infer_df[is_mirna & is_single_label].index.tolist()
        # Keep only the miRNA sequences that are not already in the default selection.
        hico_seqs = list(set(mirnas_seqs).difference(set(hico_seqs)))
    elif infer_random:
        # 200 distinct random sequences of length 18-30.
        random_seqs = []
        while len(random_seqs) < 200:
            random_seq = ''.join(random.choices(['A','C','G','T'], k=randint(18,30)))
            if random_seq not in random_seqs:
                random_seqs.append(random_seq)
        hico_seqs = random_seqs
    elif infer_fused:
        hico_seqs = get_fused_seqs(hico_seqs,num_sequences=200)

    # Only sequences of at most 30 nt are passed to the models.
    hico_seqs = [seq for seq in hico_seqs if len(seq) <= 30]

    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'

    prediction_pd = predict_transforna_all_models(hico_seqs,trained_on=trained_on,path_to_models=path_to_models)
    prediction_pd['split'] = fig_prefix

    # The three label-aware splits save their predictions at the very end,
    # after filtering to the sequences that have usable labels.
    if not infer_ood and not infer_relaxed_mirna and not infer_hico:
        prediction_pd.to_csv(f'transforna/bin/lc_files/{fig_prefix}_lev_dist_df.csv')

    if infer_aa or infer_other_affixes or infer_random or infer_fused:
        for model in prediction_pd.Model.unique():
            model_predictions = prediction_pd[prediction_pd.Model == model]
            num_non_novel = sum(model_predictions['Is Familiar?'])
            print(f'Number of non novel sequences for {model} is {num_non_novel}')
            print(f'Percent non novel for {model} is {num_non_novel/len(model_predictions)}, the lower the better')
    elif infer_na or infer_loco:
        for model in prediction_pd.Model.unique():
            model_predictions = prediction_pd[prediction_pd.Model == model]
            num_non_novel = sum(model_predictions['Is Familiar?'])
            print(f'Number of non novel sequences for {model} is {num_non_novel}')
            print(f'Percent non novel for {model} is {num_non_novel/len(model_predictions)}, the higher the better')
            print('\n')
    else:
        # A single prediction with logits is run only to read the sub classes
        # the model was trained on from the returned column names.
        prediction_single_pd = predict_transforna(hico_seqs[0],model='Seq',logits_flag = True,trained_on=trained_on,path_to_models=path_to_models)
        sub_classes_used_for_training = prediction_single_pd.columns.tolist()

        mc_classes_df,sc_classes_df = get_mc_sc(infer_df,hico_seqs,sub_classes_used_for_training,sc_to_mc_mapper_dict,ood_flag=infer_ood)

        if infer_ood:
            labelled_prediction_pd = prediction_pd[prediction_pd['Sequence'].isin(sc_classes_df.index)]
            for model in prediction_pd.Model.unique():
                curr_prediction_pd = labelled_prediction_pd[labelled_prediction_pd.Model == model]
                num_seqs = curr_prediction_pd.shape[0]

                # Copy before assigning a column so the write cannot hit a view.
                curr_prediction_pd = curr_prediction_pd[curr_prediction_pd['Is Familiar?']].copy()

                # Look the actual labels up by sequence so that they line up
                # row by row with the predictions regardless of row order.
                actual_sc_labels = sc_classes_df.loc[curr_prediction_pd['Sequence'].tolist(), 0].values

                curr_prediction_pd['Net-Label'] = correct_labels(curr_prediction_pd['Net-Label'].values,actual_sc_labels,sc_to_mc_mapper_dict)

                curr_prediction_pd = curr_prediction_pd[curr_prediction_pd['Net-Label'].values != actual_sc_labels]
                num_non_novel = len(curr_prediction_pd)
                print(f'Number of non novel sequences for {model} is {num_non_novel}')
                print(f'Percent non novel for {model} is {num_non_novel/num_seqs}, the lower the better')
                print('\n')
        else:
            compute_accuracy(prediction_pd,sc_classes_df,mc_classes_df,seperate_outliers=True,fig_prefix = fig_prefix,save_figs=save_figs)

        if infer_ood or infer_relaxed_mirna or infer_hico:
            prediction_pd = prediction_pd[prediction_pd['Sequence'].isin(sc_classes_df.index)]
            prediction_pd.to_csv(f'transforna/bin/lc_files/{fig_prefix}_lev_dist_df.csv')
|
|
|
|
|
|
|
|