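"""Entropy-based novelty analysis: in-distribution (ID) vs out-of-distribution (OOD).

Computes the Shannon entropy of model logits for each data split, fits a
univariate logistic regression that separates ID test sequences from
artificial-affix (OOD) ones, and plots/saves entropy distributions,
classifier metrics and outlier compositions.
"""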
import json
import logging

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from imblearn.under_sampling import RandomUnderSampler
from scipy.stats import entropy
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from ..utils.file import save
from ..utils.tcga_post_analysis_utils import Results_Handler
from .utlis import compute_prc, compute_roc

logger = logging.getLogger(__name__)

|
|
def entropy_clf(results, random_state: int = 1):
    """Fit a univariate logistic regression on logit entropies, separating
    artificial-affix (OOD, label 0) from in-distribution (label 1) sequences."""
    art_affix_ent = results.splits_df_dict["artificial_affix_df"]["Entropy"]["0"].values

    if results.trained_on == 'id':
        test_ent = results.splits_df_dict["test_df"]["Entropy"]["0"].values
    else:
        # no ID test split in this setting; fall back to a quarter of the train split
        test_ent = results.splits_df_dict["train_df"]["Entropy"]["0"].values[:int(0.25 * len(results.splits_df_dict["train_df"]))]

    ent_x = np.concatenate((art_affix_ent, test_ent))
    ent_labels = np.concatenate((np.zeros(art_affix_ent.shape), np.ones(test_ent.shape)))
    # train on 10% of the samples; the remaining 90% is held out for evaluation
    trainX, testX, trainy, testy = train_test_split(ent_x, ent_labels, stratify=ent_labels, test_size=0.9, random_state=random_state)

    model = LogisticRegression(solver='lbfgs', class_weight='balanced')
    model.fit(trainX.reshape(-1, 1), trainy)

    # balance the evaluation set so metrics are not dominated by the majority class
    undersample = RandomUnderSampler(sampling_strategy='majority')
    testX, testy = undersample.fit_resample(testX.reshape(-1, 1), testy)

    # keep probabilities of the positive (ID) class only
    lr_probs = model.predict_proba(testX)[:, 1]
    yhat = model.predict(testX)
    return testy, lr_probs, yhat, model
|
|
def plot_entropy(results):
    """Box-plot the entropy distribution of every split and return the upper
    whisker bound of the test split's box."""
    entropies_list = []
    test_idx = 0
    for split_idx, split in enumerate(results.splits):
        entropies_list.append(results.splits_df_dict[f"{split}_df"]["Entropy"]["0"].values)
        if split == 'test':
            test_idx = split_idx

    bx = plt.boxplot(entropies_list)
    plt.title("Entropy Distribution")
    plt.ylabel("Entropy")
    # rotate the tick labels before saving so the saved figures match the display
    plt.xticks(np.arange(len(results.splits)) + 1, results.splits, rotation=45)
    if results.save_results:
        plt.savefig(f"{results.figures_path}/entropy_test_vs_ood_boxplot.png")
        plt.savefig(f"{results.figures_path}/entropy_test_vs_ood_boxplot.svg")
    plt.show()
    # boxplot returns two whisker artists per box (lower, upper); pick the
    # upper whisker of the test split
    return [item.get_ydata()[1] for item in bx['whiskers']][2 * test_idx + 1]
|
|
def plot_entropy_per_unique_length(results, split):
    """Box-plot entropy grouped by unique sequence length for a given split."""
    seqs_len = results.splits_df_dict[f"{split}_df"]["RNA Sequences", '0'].str.len().values
    index = results.splits_df_dict[f"{split}_df"]["RNA Sequences", '0'].values
    entropies = results.splits_df_dict[f"{split}_df"]["Entropy", "0"].values

    df = pd.DataFrame({"Entropy": entropies, "Sequences Length": seqs_len}, index=index)

    ax = df.boxplot(by='Sequences Length')
    ax.set_title("")
    ax.set_xlabel(f"Sequences Length ({split})")
    ax.set_ylabel("Entropy")
    # save before show: show() flushes the current figure, so saving afterwards
    # would write an empty image
    if results.save_results:
        plt.savefig(f"{results.figures_path}/{split}_entropy_per_length_boxplot.png")
    plt.show()
|
|
def plot_outliers(results, test_whisker_UB):
    """Plot the small-RNA class composition of test sequences whose entropy
    exceeds the test split's upper whisker."""
    test_df = results.splits_df_dict["test_df"]
    test_ent = test_df["Entropy", "0"]

    outlier_seqs = test_df.iloc[(test_whisker_UB < test_ent).values]['RNA Sequences']['0'].values
    # intersect with the dataset index (the sequences), not with the frame
    # itself, which would iterate over column names
    outlier_seqs_in_ad = list(set(results.dataset.index).intersection(set(outlier_seqs)))
    # keep only high-confidence (hico) annotations
    major_class_dict = results.dataset.loc[outlier_seqs_in_ad]['small_RNA_class_annotation'][~results.dataset['hico'].isnull()].value_counts()
    major_class_dict = {x: y for x, y in major_class_dict.items() if y != 0}

    plt.pie(major_class_dict.values(), labels=major_class_dict.keys(), autopct='%1.1f%%')
    plt.axis('equal')
    # save before show, and only when saving is enabled
    if results.save_results:
        plt.savefig(f"{results.figures_path}/Decomposition_outliers_in_test_pie.png")
    plt.show()

    if results.save_results:
        save(data=outlier_seqs.tolist(), path=results.analysis_path + "/logit_outlier_seqs_ID.yaml")
|
|
def log_model_params(model, analysis_path):
    """Persist the fitted classifier's coefficient, intercept and threshold."""
    model_params = {"Model Coef": model.coef_[0][0],
                    "Model intercept": model.intercept_[0],
                    "Threshold": model.threshold}
    # round-trip through JSON to cast numpy scalars to plain Python floats
    model_params = json.loads(json.dumps(model_params))
    save(data=model_params, path=analysis_path + "/logits_model_coef.yaml")
|
|
def compute_entropy_per_split(results: Results_Handler):
    """Add a per-sample Shannon entropy column computed over each split's logits."""
    for split in results.splits:
        results.splits_df_dict[f"{split}_df"]["Entropy", "0"] = entropy(results.splits_df_dict[f"{split}_df"]["Logits"].values, axis=1)
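
# Sanity check for the entropy values (a sketch; scipy.stats.entropy
# normalises its input to a probability vector): a uniform distribution
# maximises entropy, a one-hot one minimises it.
#
#   >>> entropy([0.25, 0.25, 0.25, 0.25])
#   1.3862943611198906
#   >>> entropy([1.0, 0.0, 0.0, 0.0])
#   0.0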
|
|
|
def compute_novelty_prediction_per_split(results, model):
    """Flag a sample as a known class if its entropy is at most the classifier threshold."""
    for split in results.splits:
        results.splits_df_dict[f'{split}_df']['Novelty Prediction', 'is_known_class'] = \
            results.splits_df_dict[f'{split}_df']['Entropy', '0'] <= model.threshold
|
|
def compute_logits_clf_metrics(results):
    """Fit the entropy classifier over several random splits and log/persist
    the averaged ROC and PRC metrics."""
    aucs_roc = []
    aucs_prc = []
    f1s_prc = []
    replicates = 10
    show_figure = False
    for random_state in range(replicates):
        # only render the ROC/PRC curves for the last replicate
        if random_state == replicates - 1:
            show_figure = True

        test_labels, lr_probs, yhat, model = entropy_clf(results, random_state)

        # decision boundary of the univariate logistic model: the entropy at
        # which P(ID) = 0.5, i.e. -intercept/coef; set it before it is used
        # for the novelty predictions below, independently of save_results
        model.threshold = -model.intercept_[0] / model.coef_[0][0]
        if results.save_results:
            log_model_params(model, results.analysis_path)
        compute_novelty_prediction_per_split(results, model)

        auc_roc = compute_roc(test_labels, lr_probs, results, show_figure)
        f1_prc, auc_prc = compute_prc(test_labels, lr_probs, yhat, results, show_figure)
        aucs_roc.append(auc_roc)
        aucs_prc.append(auc_prc)
        f1s_prc.append(f1_prc)

    auc_roc_score = np.mean(aucs_roc)
    auc_roc_std = np.std(aucs_roc)
    auc_prc_score = np.mean(aucs_prc)
    auc_prc_std = np.std(aucs_prc)
    f1_prc_score = np.mean(f1s_prc)
    f1_prc_std = np.std(f1s_prc)

    logger.info(f"auc roc is {auc_roc_score} +- {auc_roc_std}")
    logger.info(f"auc prc is {auc_prc_score} +- {auc_prc_std}")
    logger.info(f"f1 prc is {f1_prc_score} +- {f1_prc_std}")

    logits_clf_metrics = {"AUC ROC score": auc_roc_score,
                          "auc_roc_std": auc_roc_std,
                          "AUC PRC score": auc_prc_score,
                          "auc_prc_std": auc_prc_std,
                          "F1 PRC score": f1_prc_score,
                          "f1_prc_std": f1_prc_std}

    # round-trip through JSON to cast numpy scalars to plain Python floats
    logits_clf_metrics = json.loads(json.dumps(logits_clf_metrics))
    if results.save_results:
        save(data=logits_clf_metrics, path=results.analysis_path + "/logits_clf_metrics.yaml")
|
|
def compute_entropies(embedds_path):
    """Entry point: compute per-split entropies, fit the ID-vs-OOD entropy
    classifier, and generate all plots."""
    logger.info("Computing entropy for ID vs OOD:")

    splits = ['train', 'valid', 'test', 'ood', 'artificial', 'no_annotation']

    results: Results_Handler = Results_Handler(embedds_path=embedds_path, splits=splits, read_dataset=True, save_results=True)

    results.append_loco_variants()

    # insert the affix/LOCO splits just before the final ('no_annotation') split
    results.splits[-1:-1] = ['artificial_affix', 'loco_not_in_train', 'loco_mixed', 'loco_in_train']

    if results.mc_flag:
        # no separate OOD split is available in this configuration
        results.splits.remove("ood")

    compute_entropy_per_split(results)

    # only the evaluation splits are analysed from here on
    results.splits.remove("train")
    results.splits.remove("valid")

    compute_logits_clf_metrics(results)

    test_whisker_UB = plot_entropy(results)
    logger.info("plotting entropy per unique length")
    plot_entropy_per_unique_length(results, 'artificial_affix')
    if 'ood' in results.splits:
        logger.info('plotting entropy per unique length for ood')
        plot_entropy_per_unique_length(results, 'ood')

    logger.info("plotting outliers")
    plot_outliers(results, test_whisker_UB)
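
if __name__ == "__main__":
    # Minimal usage sketch; "models/tcga/embedds" is a hypothetical path --
    # point it at the embeddings directory produced by a training run.
    logging.basicConfig(level=logging.INFO)
    compute_entropies("models/tcga/embedds")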
|