Update utils/sdg_classifier.py
Browse files- utils/sdg_classifier.py +31 -31
utils/sdg_classifier.py
CHANGED
@@ -14,27 +14,27 @@ except ImportError:
|
|
14 |
logging.info("Streamlit not installed")
|
15 |
|
16 |
## Labels dictionary ###
|
17 |
-
_lab_dict = {0: '
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
|
36 |
@st.cache(allow_output_mutation=True)
|
37 |
-
def
|
38 |
"""
|
39 |
loads the document classifier using haystack, where the name/path of model
|
40 |
in HF-hub as string is used to fetch the model object. Either configfile or
|
@@ -57,7 +57,7 @@ def load_sdgClassifier(config_file:str = None, classifier_name:str = None):
|
|
57 |
return
|
58 |
else:
|
59 |
config = getconfig(config_file)
|
60 |
-
classifier_name = config.get('
|
61 |
|
62 |
logging.info("Loading classifier")
|
63 |
doc_classifier = TransformersDocumentClassifier(
|
@@ -68,7 +68,7 @@ def load_sdgClassifier(config_file:str = None, classifier_name:str = None):
|
|
68 |
|
69 |
|
70 |
@st.cache(allow_output_mutation=True)
|
71 |
-
def
|
72 |
threshold:float = 0.8,
|
73 |
classifier_model:TransformersDocumentClassifier= None
|
74 |
)->Tuple[DataFrame,Series]:
|
@@ -95,10 +95,10 @@ def sdg_classification(haystack_doc:List[Document],
|
|
95 |
the number of times it is covered/discussed/count_of_paragraphs.
|
96 |
|
97 |
"""
|
98 |
-
logging.info("Working on
|
99 |
if not classifier_model:
|
100 |
if check_streamlit():
|
101 |
-
classifier_model = st.session_state['
|
102 |
else:
|
103 |
logging.warning("No streamlit envinornment found, Pass the classifier")
|
104 |
return
|
@@ -109,23 +109,23 @@ def sdg_classification(haystack_doc:List[Document],
|
|
109 |
labels_= [(l.meta['classification']['label'],
|
110 |
l.meta['classification']['score'],l.content,) for l in results]
|
111 |
|
112 |
-
df = DataFrame(labels_, columns=["
|
113 |
|
114 |
df = df.sort_values(by="Relevancy", ascending=False).reset_index(drop=True)
|
115 |
df.index += 1
|
116 |
df =df[df['Relevancy']>threshold]
|
117 |
|
118 |
# creating the dataframe for value counts of SDG, along with 'title' of SDGs
|
119 |
-
x = df['
|
120 |
x = x.rename('count')
|
121 |
-
x = x.rename_axis('
|
122 |
-
x["
|
123 |
x = x.sort_values(by=['count'], ascending=False)
|
124 |
-
x['SDG_name'] = x['
|
125 |
-
x['SDG_Num'] = x['
|
126 |
|
127 |
-
df['
|
128 |
-
df = df.sort_values('
|
129 |
|
130 |
return df, x
|
131 |
|
|
|
14 |
logging.info("Streamlit not installed")
|
15 |
|
16 |
## Labels dictionary ###
|
17 |
+
_lab_dict = {0: 'Agricultural communities',
|
18 |
+
1: 'Children',
|
19 |
+
2: 'Coastal communities',
|
20 |
+
3: 'Ethnic, racial or other minorities',
|
21 |
+
4: 'Fishery communities',
|
22 |
+
5: 'Informal sector workers',
|
23 |
+
6: 'Members of indigenous and local communities',
|
24 |
+
7: 'Migrants and displaced persons',
|
25 |
+
8: 'Older persons',
|
26 |
+
9: 'Other',
|
27 |
+
10: 'Persons living in poverty',
|
28 |
+
11: 'Persons with disabilities',
|
29 |
+
12: 'Persons with pre-existing health conditions',
|
30 |
+
13: 'Residents of drought-prone regions',
|
31 |
+
14: 'Rural populations',
|
32 |
+
15: 'Sexual minorities (LGBTQI+)',
|
33 |
+
16: 'Urban populations',
|
34 |
+
17: 'Women and other genders'}
|
35 |
|
36 |
@st.cache(allow_output_mutation=True)
|
37 |
+
def load_Classifier(config_file:str = None, classifier_name:str = None):
|
38 |
"""
|
39 |
loads the document classifier using haystack, where the name/path of model
|
40 |
in HF-hub as string is used to fetch the model object. Either configfile or
|
|
|
57 |
return
|
58 |
else:
|
59 |
config = getconfig(config_file)
|
60 |
+
classifier_name = config.get('vulnerability','MODEL')
|
61 |
|
62 |
logging.info("Loading classifier")
|
63 |
doc_classifier = TransformersDocumentClassifier(
|
|
|
68 |
|
69 |
|
70 |
@st.cache(allow_output_mutation=True)
|
71 |
+
def classification(haystack_doc:List[Document],
|
72 |
threshold:float = 0.8,
|
73 |
classifier_model:TransformersDocumentClassifier= None
|
74 |
)->Tuple[DataFrame,Series]:
|
|
|
95 |
the number of times it is covered/discussed/count_of_paragraphs.
|
96 |
|
97 |
"""
|
98 |
+
logging.info("Working on Vulnerability Classification")
|
99 |
if not classifier_model:
|
100 |
if check_streamlit():
|
101 |
+
classifier_model = st.session_state['vulnerability_classifier']
|
102 |
else:
|
103 |
logging.warning("No streamlit envinornment found, Pass the classifier")
|
104 |
return
|
|
|
109 |
labels_= [(l.meta['classification']['label'],
|
110 |
l.meta['classification']['score'],l.content,) for l in results]
|
111 |
|
112 |
+
df = DataFrame(labels_, columns=["Vulnerability","Relevancy","text"])
|
113 |
|
114 |
df = df.sort_values(by="Relevancy", ascending=False).reset_index(drop=True)
|
115 |
df.index += 1
|
116 |
df =df[df['Relevancy']>threshold]
|
117 |
|
118 |
# creating the dataframe for value counts of Vulnerability labels, along with the 'title' of each label
|
119 |
+
x = df['Vulnerability'].value_counts()
|
120 |
x = x.rename('count')
|
121 |
+
x = x.rename_axis('Vulnerability').reset_index()
|
122 |
+
x["Vulnerability"] = pd.to_numeric(x["Vulnerability"])
|
123 |
x = x.sort_values(by=['count'], ascending=False)
|
124 |
+
x['SDG_name'] = x['Vulnerability'].apply(lambda x: _lab_dict[x])
|
125 |
+
x['SDG_Num'] = x['Vulnerability'].apply(lambda x: "Vulnerability "+str(x))
|
126 |
|
127 |
+
df['Vulnerability'] = pd.to_numeric(df['Vulnerability'])
|
128 |
+
df = df.sort_values('Vulnerability')
|
129 |
|
130 |
return df, x
|
131 |
|