DarrenChensformer commited on
Commit
e3694cd
·
1 Parent(s): 71fa3d3

Add new module

Browse files
Files changed (1) hide show
  1. eval_keyphrase.py +45 -3
eval_keyphrase.py CHANGED
@@ -13,6 +13,8 @@
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
 
 
16
  import evaluate
17
  import datasets
18
 
@@ -86,10 +88,50 @@ class eval_keyphrase(evaluate.Metric):
86
  # TODO: Download external resources if needed
87
  pass
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def _compute(self, predictions, references):
90
  """Returns the scores"""
91
- # TODO: Compute the different scores of the module
92
- accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  return {
94
- "accuracy": accuracy,
 
 
 
 
95
  }
 
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
16
+ import string
17
+
18
  import evaluate
19
  import datasets
20
 
 
88
  # TODO: Download external resources if needed
89
  pass
90
 
91
+ def _normalize_keyphrase(self, kp):
92
+
93
+ def white_space_fix(text):
94
+ return ' '.join(text.split())
95
+
96
+ def remove_punc(text):
97
+ exclude = set(string.punctuation)
98
+ return ''.join(ch for ch in text if ch not in exclude)
99
+
100
+ def lower(text):
101
+ return text.lower()
102
+
103
+ return white_space_fix(remove_punc(lower(kp)))
104
+
105
  def _compute(self, predictions, references):
106
  """Returns the scores"""
107
+
108
+ macro_metrics = {'precision': [], 'recall': [], 'f1': [], 'num_pred': [], 'num_gold': []}
109
+
110
+ for targets, preds in zip(references, predictions):
111
+ targets = [self._normalize_keyphrase(tmp_key).strip() for tmp_key in targets if len(self._normalize_keyphrase(tmp_key).strip()) != 0]
112
+ preds = [self._normalize_keyphrase(tmp_key).strip() for tmp_key in preds if len(self._normalize_keyphrase(tmp_key).strip()) != 0]
113
+
114
+ total_tgt_set = set(targets)
115
+ total_preds = set(preds)
116
+ if len(total_tgt_set) == 0: continue
117
+
118
+ # get the total_correctly_matched indicators
119
+ total_correctly_matched = len(total_preds & total_tgt_set)
120
+
121
+ # macro metric calculating
122
+ precision = total_correctly_matched / len(total_preds) if len(total_preds) else 0.0
123
+ recall = total_correctly_matched / len(total_tgt_set)
124
+ f1 = 2 * precision * recall / (precision + recall) if total_correctly_matched > 0 else 0.0
125
+ macro_metrics['precision'].append(precision)
126
+ macro_metrics['recall'].append(recall)
127
+ macro_metrics['f1'].append(f1)
128
+ macro_metrics['num_pred'].append(len(total_preds))
129
+ macro_metrics['num_gold'].append(len(total_tgt_set))
130
+
131
  return {
132
+ "precision": round(sum(macro_metrics["precision"])/len(macro_metrics["precision"]), 4),
133
+ "recall": round(sum(macro_metrics["recall"])/len(macro_metrics["recall"]), 4),
134
+ "f1": round(sum(macro_metrics["f1"])/len(macro_metrics["f1"]), 4),
135
+ "num_pred": round(sum(macro_metrics["num_pred"])/len(macro_metrics["num_pred"]), 4),
136
+ "num_gold": round(sum(macro_metrics["num_gold"])/len(macro_metrics["num_gold"]), 4),
137
  }