leavoigt commited on
Commit
5cb79de
·
1 Parent(s): e10deaa

Update utils/target_classifier.py

Browse files
Files changed (1) hide show
  1. utils/target_classifier.py +29 -10
utils/target_classifier.py CHANGED
@@ -69,21 +69,40 @@ def target_classification(haystack_doc:pd.DataFrame,
69
  x: Series object with the unique SDG covered in the document uploaded and
70
  the number of times it is covered/discussed/count_of_paragraphs.
71
  """
72
- logging.info("Working on action/target extraction")
 
 
 
 
73
  if not classifier_model:
 
74
  classifier_model = st.session_state['target_classifier']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- results = classifier_model(list(haystack_doc.text))
77
- labels_= [(l[0]['label'],
78
- l[0]['score']) for l in results]
79
 
80
 
81
- df1 = DataFrame(labels_, columns=["Target Label","Target Score"])
82
- df = pd.concat([haystack_doc,df1],axis=1)
83
 
84
- df = df.sort_values(by="Target Score", ascending=False).reset_index(drop=True)
85
- df['Target Score'] = df['Target Score'].round(2)
86
- df.index += 1
87
- # df['Label_def'] = df['Target Label'].apply(lambda i: _lab_dict[i])
88
 
89
  return df
 
69
  x: Series object with the unique SDG covered in the document uploaded and
70
  the number of times it is covered/discussed/count_of_paragraphs.
71
  """
72
+
73
+ logging.info("Working on target/action identification")
74
+
75
+ haystack_doc['Vulnerability Label'] = 'NA'
76
+
77
  if not classifier_model:
78
+
79
  classifier_model = st.session_state['target_classifier']
80
+
81
+ # Get predictions
82
+ predictions = classifier_model(list(haystack_doc.text))
83
+
84
+ # Get labels for predictions
85
+ pred_labels = getlabels(predictions)
86
+
87
+ # Save labels
88
+ haystack_doc['Target Label'] = pred_labels
89
+
90
+
91
+ # logging.info("Working on action/target extraction")
92
+ # if not classifier_model:
93
+ # classifier_model = st.session_state['target_classifier']
94
 
95
+ # results = classifier_model(list(haystack_doc.text))
96
+ # labels_= [(l[0]['label'],
97
+ # l[0]['score']) for l in results]
98
 
99
 
100
+ # df1 = DataFrame(labels_, columns=["Target Label","Target Score"])
101
+ # df = pd.concat([haystack_doc,df1],axis=1)
102
 
103
+ # df = df.sort_values(by="Target Score", ascending=False).reset_index(drop=True)
104
+ # df['Target Score'] = df['Target Score'].round(2)
105
+ # df.index += 1
106
+ # # df['Label_def'] = df['Target Label'].apply(lambda i: _lab_dict[i])
107
 
108
  return df