File size: 2,020 Bytes
0b11a42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

import matplotlib.pyplot as plt
from matplotlib import pyplot
from sklearn.metrics import (auc, f1_score, precision_recall_curve,
                             roc_auc_score, roc_curve)

from ..utils.tcga_post_analysis_utils import Results_Handler


def compute_prc(test_labels,lr_probs,yhat,results:Results_Handler,show_figure:bool=False):
    """Compute the F1 score and precision-recall AUC, optionally plotting the PR curve.

    The curve is drawn whenever it will be displayed or saved, so files written
    with ``results.save_results`` actually contain the curve.

    Args:
        test_labels: Ground-truth labels (sklearn ``y_true``).
        lr_probs: Scores for the positive class; used to build the PR curve.
        yhat: Predicted labels; used for the F1 score. ``f1_score`` is called
            with its default ``average='binary'``, so binary labels are assumed.
        results: Provides ``figures_path``, the output directory whose
            second-to-last '/'-separated component becomes the legend label.
            It also provides ``save_results``, which controls whether PNG and
            SVG files are written.
        show_figure: If True, display the figure with ``plt.show()``.

    Returns:
        Tuple ``(f1, pr_auc)``. ``pr_auc`` is the trapezoidal area under the
        recall/precision curve (``sklearn.metrics.auc``).
    """
    lr_precision, lr_recall, _ = precision_recall_curve(test_labels, lr_probs)
    lr_f1, lr_auc = f1_score(test_labels, yhat), auc(lr_recall, lr_precision)

    # Nothing to draw if the figure is neither shown nor saved; avoid creating
    # an empty figure as a side effect.
    if not (show_figure or results.save_results):
        return lr_f1, lr_auc

    plt.plot(lr_recall, lr_precision, marker='.', label=results.figures_path.split('/')[-2])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.legend()
    plt.title("PRC Curve")

    if results.save_results:
        plt.savefig(f"{results.figures_path}/prc_curve.png")
        plt.savefig(f"{results.figures_path}/prc_curve.svg")

    if show_figure:
        plt.show()
    else:
        # Close the saved-only figure so its curve does not bleed into the
        # next plot drawn on the current axes.
        plt.close()
    return lr_f1, lr_auc

def compute_roc(test_labels,lr_probs,results:Results_Handler,show_figure:bool=False):
    """Compute the ROC AUC, optionally plotting the ROC curve.

    The curve is drawn whenever it will be displayed or saved, so files written
    with ``results.save_results`` actually contain the curve.

    Args:
        test_labels: Ground-truth labels (sklearn ``y_true``).
        lr_probs: Scores for the positive class; passed to ``roc_auc_score``
            and ``roc_curve``.
        results: Provides ``figures_path``, the output directory whose
            second-to-last '/'-separated component becomes the legend label.
            It also provides ``save_results``, which controls whether PNG and
            SVG files are written.
        show_figure: If True, display the figure with ``plt.show()``.

    Returns:
        The ROC AUC from ``sklearn.metrics.roc_auc_score``.
    """
    lr_auc = roc_auc_score(test_labels, lr_probs)

    # Nothing to draw if the figure is neither shown nor saved; avoid creating
    # an empty figure as a side effect.
    if not (show_figure or results.save_results):
        return lr_auc

    lr_fpr, lr_tpr, _ = roc_curve(test_labels, lr_probs)
    plt.plot(lr_fpr, lr_tpr, marker='.',markersize=1, label=results.figures_path.split('/')[-2])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend()
    plt.title("ROC Curve")

    if results.save_results:
        plt.savefig(f"{results.figures_path}/roc_curve.png")
        plt.savefig(f"{results.figures_path}/roc_curve.svg")

    if show_figure:
        plt.show()
    else:
        # Close the saved-only figure so its curve does not bleed into the
        # next plot drawn on the current axes.
        plt.close()
    return lr_auc