import logging

import pandas as pd
import streamlit as st
from pandas import DataFrame
from setfit import SetFitModel

from utils.config import getconfig
from utils.preprocessing import processingpipeline


# Mapping from predicted class index to a human-readable label.
label_dict = {
    0: 'NO',
    1: 'YES',
}


def get_target_labels(preds):
    """
    Takes the numerical predictions of the classifier (one row per paragraph)
    and returns the corresponding list of human-readable labels.
    """
    preds_list = preds.numpy().tolist()

    predictions_names = []
    for ele in preds_list:
        # Locate the predicted class in the one-hot row; rows without a 1
        # carry no prediction and are labelled "Other".
        try:
            index_of_one = ele.index(1)
        except ValueError:
            index_of_one = "NA"

        if index_of_one != "NA":
            name = label_dict[index_of_one]
        else:
            name = "Other"

        predictions_names.append(name)

    return predictions_names

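
# Illustrative example for get_target_labels (this assumes the classifier
# returns one-hot rows as a tensor, which is what the function above expects):
# a prediction tensor like [[1, 0], [0, 1], [0, 0]] is mapped to
# ['NO', 'YES', 'Other'].

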
@st.cache_resource
def load_targetClassifier(config_file: str = None, classifier_name: str = None):
    """
    Loads the target classifier as a SetFit model, where the name/path of the
    model on the HF hub is used to fetch the model object. Either a config
    file or a model name must be passed.

    Params
    --------
    config_file: config file path from which to read the model name
    classifier_name: if a model name is passed it takes priority; if not
        found, the config file is used, else a warning is logged.

    Return: document classifier model
    """
    if not classifier_name:
        if not config_file:
            logging.warning("Pass either model name or config file")
            return
        else:
            config = getconfig(config_file)
            classifier_name = config.get('target', 'MODEL')

    logging.info("Loading classifier")
    doc_classifier = SetFitModel.from_pretrained(classifier_name)

    return doc_classifier

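
# A minimal sketch of the config file this loader expects, assuming getconfig
# returns a configparser-style object (the section and key names follow the
# config.get('target', 'MODEL') lookup above):
#
#   [target]
#   MODEL = leavoigt/vulnerability_target

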
@st.cache_data
def target_classification(haystack_doc: pd.DataFrame,
                          threshold: float = 0.5,
                          classifier_model: SetFitModel = None
                          ) -> DataFrame:
    """
    Text classification on the texts provided. The classifier assigns the most
    appropriate label to each text; the label indicates whether the paragraph
    references a specific action, target or measure.

    Params
    ---------
    haystack_doc: DataFrame of paragraphs. The output of the preprocessing
        pipeline contains the paragraphs in different formats; here the
        DataFrame with a 'text' column is used.
    threshold: threshold value for the model to keep the results from the
        classifier.
    classifier_model: you can pass the classifier model directly, which takes
        priority; otherwise the model is looked up in the Streamlit session
        state. In a Streamlit app, avoid passing the model directly.

    Returns
    ----------
    df: the input DataFrame with an additional 'Target Label' column holding
        the predicted label for each paragraph.
    """
    logging.info("Working on target/action identification")

    haystack_doc['Target Label'] = 'NA'
    # Debug output rendered in the Streamlit UI.
    st.write("haystack_doc")
    st.write(haystack_doc)

    if not classifier_model:
        st.write("No classifier_model")
        # Fall back to the classifier stored in the Streamlit session state.
        classifier_model = st.session_state['target_classifier']
        st.write("classifier model defined")

    predictions = classifier_model(list(haystack_doc.text))
    st.write("predictions made")
    st.write(predictions)

    pred_labels = get_target_labels(predictions)
    st.write("pred_labels")
    st.write(pred_labels)

    haystack_doc['Target Label'] = pred_labels

    return haystack_doc
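

# A minimal usage sketch (illustrative only; the config path and the example
# paragraph below are assumptions, not part of this module):
#
#   classifier = load_targetClassifier(config_file='paramconfig.cfg')
#   st.session_state['target_classifier'] = classifier
#   docs = pd.DataFrame({'text': ['We aim to reduce emissions by 40% by 2030.']})
#   docs = target_classification(docs, classifier_model=classifier)
#   st.write(docs[['text', 'Target Label']])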