import logging
from typing import List, Optional, Tuple

import pandas as pd
from pandas import DataFrame, Series
import streamlit as st
from transformers import pipeline
from typing_extensions import Literal

from utils.config import getconfig
from utils.preprocessing import processingpipeline


## Labels dictionary ###
# Maps the class id (as a string) to the human readable label.
label_dict = {
    '0': 'NO',
    '1': 'YES',
}


def _label_for_prediction(row) -> str:
    """
    Translate a single prediction into a label name.

    Two shapes are understood:

    * a transformers text-classification result, i.e. a dict with a
      ``'label'`` key or a list whose first item is such a dict (the shape
      produced with ``top_k=1``);
    * a one-hot row (list of 0/1 values) where the position of the ``1``
      is the class id.

    Returns "Other" when a one-hot row contains no ``1`` or when the class
    id has no entry in ``label_dict``.
    """
    # Unwrap the single-element list returned per text when top_k=1.
    if isinstance(row, (list, tuple)) and row and isinstance(row[0], dict):
        row = row[0]

    if isinstance(row, dict):
        raw_label = str(row.get('label', ''))
        # NOTE(review): assumes labels look like 'LABEL_1' or '1' and that the
        # trailing id lines up with label_dict keys - confirm against the
        # model config. Labels that cannot be mapped are passed through as is.
        class_id = raw_label.rsplit('_', 1)[-1]
        return label_dict.get(class_id, raw_label or "Other")

    try:
        index_of_one = row.index(1)
    except ValueError:
        # No class was predicted for this row.
        return "Other"
    # label_dict is keyed by strings, so the integer position must be
    # converted before the lookup.
    return label_dict.get(str(index_of_one), "Other")


def get_target_labels(preds) -> List[str]:
    """
    Function that takes the numerical predictions as an input and returns a
    list of the labels.

    Params
    --------
    preds: either an array-like exposing ``tolist()`` whose rows are one-hot
        predictions, or the list returned by a transformers
        text-classification pipeline.

    Return: list with one label name per prediction ("Other" when no
        prediction was made for a row).
    """
    # Plain lists (e.g. pipeline output) have no tolist(); accept both.
    preds_list = preds.tolist() if hasattr(preds, 'tolist') else list(preds)
    logging.debug("preds_list: %s", preds_list)
    return [_label_for_prediction(ele) for ele in preds_list]


@st.cache_resource
def load_targetClassifier(config_file: Optional[str] = None,
                          classifier_name: Optional[str] = None):
    """
    Loads the target classifier as a transformers "text-classification"
    pipeline. The name/path of the model is either passed directly or read
    from the config file. Either config_file or classifier_name should be
    passed.

    Params
    --------
    config_file: config file path from which to read the model name
        (option 'MODEL' of section 'target').
    classifier_name: if the model name is passed, it takes priority; if not
        given the config file is used. If neither is given a warning is
        logged.

    Return: the text-classification pipeline (created with top_k=1), or None
        when neither argument was supplied.
    """
    if not classifier_name:
        if not config_file:
            logging.warning("Pass either model name or config file")
            return None
        config = getconfig(config_file)
        classifier_name = config.get('target', 'MODEL')

    logging.info("Loading classifier")
    doc_classifier = pipeline("text-classification",
                              model=classifier_name,
                              top_k=1)
    return doc_classifier


@st.cache_data
def target_classification(haystack_doc: pd.DataFrame,
                          threshold: float = 0.5,
                          classifier_model: pipeline = None
                          ) -> DataFrame:
    """
    Text-Classification on the list of texts provided. Classifier provides
    the most appropriate label for each text. The labels indicate whether the
    paragraph references a specific action, target or measure.

    Params
    ---------
    haystack_doc: DataFrame with a 'text' column holding the paragraphs to
        classify. A 'Target Label' column is written on this frame.
    threshold: threshold value for the model to keep the results from
        classifier. Currently not used by this function; kept for interface
        compatibility.
    classifier_model: you can pass the classifier model directly, which takes
        priority; if not passed the model is read from
        st.session_state['target_classifier']. In case of streamlit avoid
        passing the model directly.

    Returns
    ----------
    haystack_doc: the input DataFrame with the added 'Target Label' column.

    Raises
    ----------
    KeyError: if no model is passed and 'target_classifier' is missing from
        the streamlit session state.
    """
    logging.info("Working on target/action identification")
    haystack_doc['Target Label'] = 'NA'

    # Nothing to classify: skip the model call entirely.
    if haystack_doc.empty:
        return haystack_doc

    if classifier_model is None:
        classifier_model = st.session_state['target_classifier']

    # Get predictions
    predictions = classifier_model(list(haystack_doc['text']))
    logging.debug("predictions: %s", predictions[:10])

    # Get labels for predictions
    pred_labels = get_target_labels(predictions)
    logging.debug("pred_labels: %s", pred_labels[:10])

    # Save labels
    haystack_doc['Target Label'] = pred_labels
    return haystack_doc