File size: 5,527 Bytes
079c7c0 ef9edc5 079c7c0 40953ed 079c7c0 cbf120d 079c7c0 d30cd5d 41aa9dd d30cd5d 154ee8f ef9edc5 079c7c0 ea878f4 ef9edc5 079c7c0 d30cd5d 079c7c0 d6606ef 4bf4856 e692c88 d6606ef 897a5b5 4a76f03 4caba7c 4a76f03 897a5b5 4a76f03 897a5b5 b5de5f6 d6606ef 4a76f03 ea878f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
# set path
import glob, os, sys;
sys.path.append('../utils')
#import needed libraries
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from utils.vulnerability_classifier import load_vulnerabilityClassifier, vulnerability_classification
import logging
logger = logging.getLogger(__name__)
from utils.config import get_classifier_params
from utils.preprocessing import paraLengthCheck
from io import BytesIO
import xlsxwriter
import plotly.express as px
from utils.vulnerability_classifier import label_dict
# Module-level configuration for this page.
# The identifier is used both to fetch this classifier's parameters and to
# build the session-state key the loaded model is stored under in app().
classifier_identifier = "vulnerability"
params = get_classifier_params(classifier_identifier)
@st.cache_data
def to_excel(df, sectorlist):
    """Serialise ``df`` to an in-memory .xlsx workbook with drop-down cells.

    The DataFrame is written without its index to a sheet named ``Sheet1``.
    For every data row, column S gets a drop-down restricted to
    ``'No'`` / ``'Yes'`` / ``'Discard'``, and columns T, U, V, W and X get a
    drop-down restricted to the entries of ``sectorlist`` plus ``'Blank'``.

    Args:
        df: DataFrame to export.
        sectorlist: Iterable of strings offered as choices in columns T-X.

    Returns:
        bytes: The contents of the generated .xlsx file.
    """
    output = BytesIO()
    # Row 1 holds the header written by ``DataFrame.to_excel``, so the data
    # occupies rows 2 .. len(df) + 1 in 1-based A1 notation.
    last_row = len(df) + 1
    sector_options = list(sectorlist) + ['Blank']
    # Leaving the context manager closes the writer, which flushes the
    # workbook into ``output`` (``ExcelWriter.save()`` was removed in pandas 2.0).
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, index=False, sheet_name='Sheet1')
        worksheet = writer.sheets['Sheet1']
        # An empty DataFrame has no data rows to attach validation to.
        if last_row >= 2:
            worksheet.data_validation(
                f'S2:S{last_row}',
                {'validate': 'list', 'source': ['No', 'Yes', 'Discard']},
            )
            for column in ('T', 'U', 'V', 'W', 'X'):
                worksheet.data_validation(
                    f'{column}2:{column}{last_row}',
                    {'validate': 'list', 'source': sector_options},
                )
    return output.getvalue()
def app():
    """Run the vulnerability classifier on the processed document, if any.

    Nothing happens unless a processed document is stored under
    ``st.session_state.key0``. When one is present:

    - The classifier named by ``params['model_name']`` is loaded.
    - It is kept in the session state under ``'<classifier_identifier>_classifier'``.
    - The document is classified with ``params['threshold']``.
    - The returned DataFrame is stored as ``st.session_state.key1``.
    """
    with st.container():
        # Guard: no document has been processed yet, so there is nothing to do.
        if 'key0' not in st.session_state:
            return

        document = st.session_state.key0

        model = load_vulnerabilityClassifier(classifier_name=params['model_name'])
        model_key = '{}_classifier'.format(classifier_identifier)
        st.session_state[model_key] = model

        # NOTE(review): the model is not passed explicitly; presumably
        # vulnerability_classification reads it from the session-state entry
        # stored just above -- confirm.
        classified = vulnerability_classification(
            haystack_doc=document,
            threshold=params['threshold'],
        )

        st.session_state.key1 = classified
def vulnerability_display():
    """Render the vulnerability-analysis results for the processed document.

    Reads the DataFrame stored under ``st.session_state['key0']``. Each row's
    ``'Vulnerability Label'`` entry is iterated as a collection of group
    labels. A row counts as a reference to a vulnerable group when ``'Other'``
    is not among its labels.

    The page shows:

    - a short summary in the left column;
    - a bar chart in the right column of how often each group from
      ``label_dict`` occurs, excluding ``'Other'``;
    - a table of the referencing rows.
    """
    from collections import Counter

    # NOTE(review): this reads key0 while app() stores the classifier output
    # under key1; it relies on the 'Vulnerability Label' column being present
    # on the key0 DataFrame -- confirm.
    df_vul = st.session_state['key0']
    # NOTE(review): this renders the whole DataFrame above the summary and
    # looks like leftover debugging output -- confirm it is intended.
    st.write(df_vul)

    vul_labels = df_vul['Vulnerability Label']
    # Computed once and reused for the summary count and the final table.
    is_reference = vul_labels.apply(lambda labels: 'Other' not in labels)

    col1, col2 = st.columns([1, 1])
    with col1:
        st.subheader("Explore references to vulnerable groups:")
        num_paragraphs = len(vul_labels)
        num_references = is_reference.sum()
        # Each paragraph gets its own balanced <div>; the original markup
        # closed a <div> that was never opened.
        summary_html = (
            '<div style="text-align: justify;"> The document contains a '
            f'total of <span style="color: red;">{num_paragraphs}</span> paragraphs. '
            f'We identified <span style="color: red;">{num_references}</span> '
            'references to vulnerable groups.</div>'
            '<br>'
            '<div style="text-align: justify;">In the bar chart on the right you '
            'can see the distribution of the different groups defined. For a more '
            'detailed view in the text, see the paragraphs and their respective '
            'labels in the table below.</div>'
        )
        st.markdown(summary_html, unsafe_allow_html=True)

    with col2:
        # Start from every label in label_dict so that groups which never
        # occur still appear in the chart.
        df_labels = pd.DataFrame(list(label_dict.items()), columns=['Label ID', 'Label'])
        # Count occurrences of each label across all rows.
        group_counts = Counter(label for labels in vul_labels for label in labels)
        df_label_count = pd.DataFrame(list(group_counts.items()), columns=['Label', 'Count'])
        df_label_count = df_labels.merge(df_label_count, on='Label', how='left')
        # Labels absent from the data have no count after the left merge.
        df_label_count['Count'] = df_label_count['Count'].fillna(0).astype(int)
        df_bar_chart = df_label_count[df_label_count['Label'] != 'Other']
        fig = px.bar(df_bar_chart,
                     x='Label',
                     y='Count',
                     title='How many references have been found to each group?',
                     labels={'Count': 'Frequency'})
        st.plotly_chart(fig, use_container_width=True)

    # Table of the rows that reference a vulnerable group.
    st.write(df_vul[is_reference])
|