# set path
import glob, os, sys
sys.path.append('../utils')

# import needed libraries
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import logging
from io import BytesIO
import xlsxwriter
import plotly.express as px
import plotly.graph_objects as go

from utils.vulnerability_classifier import (load_vulnerabilityClassifier,
                                            vulnerability_classification,
                                            label_dict)
from utils.config import get_classifier_params
from utils.preprocessing import paraLengthCheck

logger = logging.getLogger(__name__)

# Declare all the necessary variables
classifier_identifier = 'vulnerability'
params = get_classifier_params(classifier_identifier)


@st.cache_data
def to_excel(df, sectorlist):
    """Write df to an in-memory Excel file and add dropdown validations."""
    len_df = len(df)
    output = BytesIO()
    writer = pd.ExcelWriter(output, engine='xlsxwriter')
    df.to_excel(writer, index=False, sheet_name='Sheet1')
    workbook = writer.book
    worksheet = writer.sheets['Sheet1']

    # With the header in row 1, the data occupies rows 2 to len_df + 1,
    # so the validation ranges must extend to len_df + 1.
    last_row = len_df + 1
    worksheet.data_validation('S2:S{}'.format(last_row),
                              {'validate': 'list',
                               'source': ['No', 'Yes', 'Discard']})
    worksheet.data_validation('T2:T{}'.format(last_row),
                              {'validate': 'list',
                               'source': sectorlist + ['Blank']})
    worksheet.data_validation('U2:U{}'.format(last_row),
                              {'validate': 'list',
                               'source': sectorlist + ['Blank']})
    worksheet.data_validation('V2:V{}'.format(last_row),
                              {'validate': 'list',
                               'source': sectorlist + ['Blank']})
    worksheet.data_validation('W2:W{}'.format(last_row),
                              {'validate': 'list',
                               'source': sectorlist + ['Blank']})
    worksheet.data_validation('X2:X{}'.format(last_row),
                              {'validate': 'list',
                               'source': sectorlist + ['Blank']})

    # close() writes the workbook to the buffer (ExcelWriter.save() was
    # removed in pandas 2.0)
    writer.close()
    processed_data = output.getvalue()
    return processed_data


def app():

    ### Main app code ###
    with st.container():

        # If a document has been processed
        if 'key0' in st.session_state:

            # Run vulnerability classifier
            df = st.session_state.key0
            classifier = load_vulnerabilityClassifier(classifier_name=params['model_name'])
            st.session_state['{}_classifier'.format(classifier_identifier)] = classifier

            # Get the predictions
            df = vulnerability_classification(haystack_doc=df,
                                              threshold=params['threshold'])

            # Store df in session state with key1
            st.session_state.key1 = df


def vulnerability_display():

    # Get the vulnerability df
    df = st.session_state['key1']

    # Keep only the paragraphs that reference at least one group other than
    # "Other"; copy() avoids pandas' SettingWithCopyWarning when renaming below
    df_filtered = df[df['Vulnerability Label'].apply(
        lambda x: len(x) > 0 and 'Other' not in x)].copy()

    # Rename column
    df_filtered.rename(columns={'Vulnerability Label': 'Group(s)'}, inplace=True)

    # Header
    st.subheader("Explore references to vulnerable groups:")

    # Intro text
    num_paragraphs = len(df['Vulnerability Label'])
    num_references = len(df_filtered['Group(s)'])

    st.markdown(f"""<div style="text-align: justify;">The document contains a total of <span style="color: red;">{num_paragraphs}</span> paragraphs.
        We identified <span style="color: red;">{num_references}</span> references to groups in vulnerable situations.</div>
        <br>
        <div style="text-align: justify;">We are searching for references related to the following groups: (1) Agricultural communities, (2) Children, (3) Ethnic, racial and other minorities, (4) Fishery communities, (5) Informal sector workers, (6) Members of indigenous and local communities, (7) Migrants and displaced persons, (8) Older persons, (9) Persons living in poverty, (10) Persons living with disabilities, (11) Persons with pre-existing health conditions, (12) Residents of drought-prone regions, (13) Rural populations, (14) Sexual minorities (LGBTQI+), (15) Urban populations, (16) Women and other genders.</div>
        <br>
        <div style="text-align: justify;">The chart below shows the groups for which references were found and the number of references identified. For a more detailed view in the text, see the paragraphs and their respective labels in the table underneath.</div>""",
        unsafe_allow_html=True)

    ### Bar chart ###

    # Create a df that stores all the labels
    df_labels = pd.DataFrame(list(label_dict.items()), columns=['Label ID', 'Label'])

    # Count how often each label appears in the "Group(s)" column
    group_counts = {}
    for index, row in df_filtered.iterrows():
        # Each row holds a list of group labels
        for group in row['Group(s)']:
            group_counts[group] = group_counts.get(group, 0) + 1

    # Create a new dataframe from group_counts
    df_label_count = pd.DataFrame(list(group_counts.items()), columns=['Label', 'Count'])

    # Merge the label counts with the df_labels dataframe
    df_label_count = df_labels.merge(df_label_count, on='Label', how='left')

    # Exclude the "Other" group and all groups that do not have a label
    df_bar_chart = df_label_count[df_label_count['Label'] != 'Other']
    df_bar_chart = df_bar_chart.dropna(subset=['Count'])

    # Horizontal bar chart of the number of references per group
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=df_bar_chart.Label,
        x=df_bar_chart.Count,
        orientation='h',
        marker=dict(color='purple'),
    ))

    # Customize layout
    fig.update_layout(
        title='Number of references identified',
        xaxis_title='Number of references',
        yaxis_title='Group',
    )

    # Show the plot
    st.plotly_chart(fig, use_container_width=True)
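

# Illustrative usage sketch (an assumption, not part of the original page):
# to_excel() is defined above but never called in this module. When classified
# results exist in session state, they could feed a Streamlit download button
# roughly as below. The function name, group list, and file name are
# hypothetical placeholders.
def vulnerability_download_sketch():
    # Use every known group label except the residual 'Other' class
    vuln_groups = [v for v in label_dict.values() if v != 'Other']
    excel_bytes = to_excel(st.session_state['key1'], vuln_groups)
    st.download_button(
        label='Download results as Excel',
        data=excel_bytes,
        file_name='vulnerability_analysis.xlsx',  # hypothetical file name
        mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
    )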