Update appStore/vulnerability_analysis.py
Browse files
appStore/vulnerability_analysis.py
CHANGED
@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
|
|
8 |
import numpy as np
|
9 |
import pandas as pd
|
10 |
import streamlit as st
|
11 |
-
from utils.
|
12 |
import logging
|
13 |
logger = logging.getLogger(__name__)
|
14 |
from utils.config import get_classifier_params
|
@@ -59,7 +59,7 @@ def app():
|
|
59 |
|
60 |
if 'key1' in st.session_state:
|
61 |
df = st.session_state.key1
|
62 |
-
classifier =
|
63 |
st.session_state['{}_classifier'.format(classifier_identifier)] = classifier
|
64 |
|
65 |
if sum(df['Target Label'] == 'TARGET') > 100:
|
@@ -67,7 +67,7 @@ def app():
|
|
67 |
else:
|
68 |
warning_msg = ""
|
69 |
|
70 |
-
df =
|
71 |
threshold= params['threshold'])
|
72 |
|
73 |
st.session_state.key1 = df
|
|
|
8 |
import numpy as np
|
9 |
import pandas as pd
|
10 |
import streamlit as st
|
11 |
+
from utils.vulnerability_classifier import load_vulnerabilityClassifier, vulnerability_classification
|
12 |
import logging
|
13 |
logger = logging.getLogger(__name__)
|
14 |
from utils.config import get_classifier_params
|
|
|
59 |
|
60 |
if 'key1' in st.session_state:
|
61 |
df = st.session_state.key1
|
62 |
+
classifier = load_vulnerabilityClassifier(classifier_name=params['model_name'])
|
63 |
st.session_state['{}_classifier'.format(classifier_identifier)] = classifier
|
64 |
|
65 |
if sum(df['Target Label'] == 'TARGET') > 100:
|
|
|
67 |
else:
|
68 |
warning_msg = ""
|
69 |
|
70 |
+
df = vulnerability_classification(haystack_doc=df,
|
71 |
threshold= params['threshold'])
|
72 |
|
73 |
st.session_state.key1 = df
|