Update utils/group_classifier.py
Browse files
utils/group_classifier.py
CHANGED
@@ -19,7 +19,7 @@ _lab_dict = {
|
|
19 |
6: 'Women'}
|
20 |
|
21 |
@st.cache_resource
|
22 |
-
def
|
23 |
"""
|
24 |
loads the document classifier using haystack, where the name/path of model
|
25 |
in HF-hub as string is used to fetch the model object.Either configfile or
|
@@ -51,7 +51,7 @@ def load_targetClassifier(config_file:str = None, classifier_name:str = None):
|
|
51 |
|
52 |
|
53 |
@st.cache_data
|
54 |
-
def
|
55 |
threshold:float = 0.5,
|
56 |
classifier_model:pipeline= None
|
57 |
)->Tuple[DataFrame,Series]:
|
@@ -74,20 +74,20 @@ def target_classification(haystack_doc:pd.DataFrame,
|
|
74 |
x: Series object with the unique SDG covered in the document uploaded and
|
75 |
the number of times it is covered/discussed/count_of_paragraphs.
|
76 |
"""
|
77 |
-
logging.info("Working on
|
78 |
if not classifier_model:
|
79 |
-
classifier_model = st.session_state['
|
80 |
|
81 |
results = classifier_model(list(haystack_doc.text))
|
82 |
labels_= [(l[0]['label'],
|
83 |
l[0]['score']) for l in results]
|
84 |
|
85 |
|
86 |
-
df1 = DataFrame(labels_, columns=["
|
87 |
df = pd.concat([haystack_doc,df1],axis=1)
|
88 |
|
89 |
df = df.sort_values(by="Relevancy", ascending=False).reset_index(drop=True)
|
90 |
df.index += 1
|
91 |
-
df['Label_def'] = df['
|
92 |
|
93 |
return df
|
|
|
19 |
6: 'Women'}
|
20 |
|
21 |
@st.cache_resource
|
22 |
+
def load_groupClassifier(config_file:str = None, classifier_name:str = None):
|
23 |
"""
|
24 |
loads the document classifier using haystack, where the name/path of model
|
25 |
in HF-hub as string is used to fetch the model object.Either configfile or
|
|
|
51 |
|
52 |
|
53 |
@st.cache_data
|
54 |
+
def group_classification(haystack_doc:pd.DataFrame,
|
55 |
threshold:float = 0.5,
|
56 |
classifier_model:pipeline= None
|
57 |
)->Tuple[DataFrame,Series]:
|
|
|
74 |
x: Series object with the unique SDG covered in the document uploaded and
|
75 |
the number of times it is covered/discussed/count_of_paragraphs.
|
76 |
"""
|
77 |
+
logging.info("Working on Group Extraction")
|
78 |
if not classifier_model:
|
79 |
+
classifier_model = st.session_state['group_classifier']
|
80 |
|
81 |
results = classifier_model(list(haystack_doc.text))
|
82 |
labels_= [(l[0]['label'],
|
83 |
l[0]['score']) for l in results]
|
84 |
|
85 |
|
86 |
+
df1 = DataFrame(labels_, columns=["Group Label","Relevancy"])
|
87 |
df = pd.concat([haystack_doc,df1],axis=1)
|
88 |
|
89 |
df = df.sort_values(by="Relevancy", ascending=False).reset_index(drop=True)
|
90 |
df.index += 1
|
91 |
+
df['Label_def'] = df['Group Label'].apply(lambda i: _lab_dict[i])
|
92 |
|
93 |
return df
|