File size: 4,631 Bytes
fcd0a70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from collections import defaultdict
import numpy as np
import torch
from seqeval.metrics.v1 import _prf_divide
def extract_tp_actual_correct(y_true, y_pred):
    """Count predicted, true-positive and gold entities per entity type.

    Entities are deduplicated per type as ``(start, end, idx)`` sets, so an
    exact duplicate entry is counted once.

    Args:
        y_true: Flat iterable of gold entities, each unpackable as
            ``(type_name, (start, end), idx)``; ``idx`` distinguishes
            otherwise-identical spans (e.g. the sentence index appended by
            ``flatten_for_eval``).
        y_pred: Flat iterable of predicted entities in the same layout.

    Returns:
        ``(pred_sum, tp_sum, true_sum, target_names)``: three int32 arrays
        aligned index-by-index with ``target_names``, the sorted union of the
        entity types seen in either input.
    """
    entities_true = defaultdict(set)
    entities_pred = defaultdict(set)
    for type_name, (start, end), idx in y_true:
        entities_true[type_name].add((start, end, idx))
    for type_name, (start, end), idx in y_pred:
        entities_pred[type_name].add((start, end, idx))

    target_names = sorted(set(entities_true) | set(entities_pred))

    # .get() avoids inserting empty entries into the defaultdicts.
    per_type = [
        (entities_true.get(name, set()), entities_pred.get(name, set()))
        for name in target_names
    ]
    # Build each array once: repeated np.append reallocates on every call and
    # silently upcasts to the platform int instead of the intended int32.
    tp_sum = np.array([len(gold & pred) for gold, pred in per_type], dtype=np.int32)
    pred_sum = np.array([len(pred) for _, pred in per_type], dtype=np.int32)
    true_sum = np.array([len(gold) for gold, _ in per_type], dtype=np.int32)
    return pred_sum, tp_sum, true_sum, target_names
def flatten_for_eval(y_true, y_pred):
    """Flatten per-sentence entity lists, tagging each entity with its sentence index.

    Sentences are paired with ``zip``, so any extra sentences on the longer
    side are ignored. Each entity must be a list, because the sentence index
    is attached with ``entity + [index]``.

    Returns:
        ``(all_true, all_pred)``: flat lists of ``entity + [sentence_index]``.
    """
    sentence_pairs = list(zip(y_true, y_pred))
    gold_flat = [
        entity + [sent_no]
        for sent_no, (gold, _) in enumerate(sentence_pairs)
        for entity in gold
    ]
    pred_flat = [
        entity + [sent_no]
        for sent_no, (_, pred) in enumerate(sentence_pairs)
        for entity in pred
    ]
    return gold_flat, pred_flat
def compute_prf(y_true, y_pred, average='micro'):
    """Compute entity-level precision, recall and F1.

    A predicted entity counts as a true positive only when its label,
    ``(start, end)`` span and sentence position all match a gold entity.

    Args:
        y_true: Per-sentence gold entities; each sentence is a list of
            ``[label, (start, end)]`` lists.
        y_pred: Per-sentence predicted entities in the same layout.
        average: ``'micro'`` (default) pools the counts of all entity types
            before dividing. ``'macro'`` is the unweighted mean of the
            per-type scores. ``'weighted'`` weights the per-type scores by
            their number of gold entities.

    Returns:
        dict with keys ``'precision'``, ``'recall'`` and ``'f_score'``,
        inserted in that order.

    Raises:
        ValueError: If ``average`` is not one of the supported modes.
    """
    if average not in ('micro', 'macro', 'weighted'):
        raise ValueError(
            f"average must be 'micro', 'macro' or 'weighted', got {average!r}"
        )
    y_true, y_pred = flatten_for_eval(y_true, y_pred)
    pred_sum, tp_sum, true_sum, _ = extract_tp_actual_correct(y_true, y_pred)

    if average == 'micro':
        # Pool counts over all types; every array below becomes length 1.
        tp_sum = np.array([tp_sum.sum()])
        pred_sum = np.array([pred_sum.sum()])
        true_sum = np.array([true_sum.sum()])
    elif tp_sum.size == 0:
        # No entity types on either side: there is nothing to average over.
        return {'precision': 0.0, 'recall': 0.0, 'f_score': 0.0}

    # With zero_division='warn', seqeval's _prf_divide is expected to yield 0
    # (plus a warning) where the denominator is 0 instead of NaN.
    precision = _prf_divide(
        numerator=tp_sum,
        denominator=pred_sum,
        metric='precision',
        modifier='predicted',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )
    recall = _prf_divide(
        numerator=tp_sum,
        denominator=true_sum,
        metric='recall',
        modifier='true',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )
    denominator = precision + recall
    # Where P + R == 0 the numerator is also 0, so F1 becomes 0 instead of NaN.
    denominator[denominator == 0.] = 1
    f_score = 2 * (precision * recall) / denominator

    if average == 'micro':
        return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}

    weights = None
    if average == 'weighted':
        weights = true_sum
        if weights.sum() == 0:
            # A support-weighted mean is undefined without any gold entities.
            return {'precision': 0.0, 'recall': 0.0, 'f_score': 0.0}
    return {
        'precision': np.average(precision, weights=weights),
        'recall': np.average(recall, weights=weights),
        'f_score': np.average(f_score, weights=weights),
    }
class Evaluator:
    """Score predicted entity spans against gold entity spans.

    ``all_true`` and ``all_outs`` are parallel per-sentence sequences whose
    items unpack as three-element ``(start, end, label)`` records.
    """

    def __init__(self, all_true, all_outs):
        # Gold entities, one collection per sentence.
        self.all_true = all_true
        # Predicted entities, one collection per sentence.
        self.all_outs = all_outs

    def get_entities_fr(self, ents):
        """Reorder ``(start, end, label)`` records into ``[label, (start, end)]`` lists."""
        # The outer items must be lists, not tuples, because flatten_for_eval
        # extends each one with ``+ [sentence_index]``.
        return [[label, (start, end)] for start, end, label in ents]

    def transform_data(self):
        """Convert both gold and predicted sentences with ``get_entities_fr``.

        Returns:
            ``(all_true_ent, all_outs_ent)``, the two converted per-sentence lists.
        """
        converted = [
            (self.get_entities_fr(gold), self.get_entities_fr(pred))
            for gold, pred in zip(self.all_true, self.all_outs)
        ]
        gold_side = [gold for gold, _ in converted]
        pred_side = [pred for _, pred in converted]
        return gold_side, pred_side

    @torch.no_grad()
    def evaluate(self):
        """Return a formatted ``P/R/F1`` summary line and the F1 value."""
        gold_ents, pred_ents = self.transform_data()
        scores = compute_prf(gold_ents, pred_ents)
        precision = scores['precision']
        recall = scores['recall']
        f1 = scores['f_score']
        summary = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
        return summary, f1
def is_nested(idx1, idx2):
    """Return True when one span contains the other.

    Bounds are compared inclusively, so spans that share an edge (or are
    identical) count as nested. Only the first two items of each span,
    ``[0]`` and ``[1]``, are read.
    """
    if idx1[0] <= idx2[0] and idx2[1] <= idx1[1]:
        # idx2 lies inside idx1.
        return True
    # Otherwise check whether idx1 lies inside idx2.
    return idx2[0] <= idx1[0] and idx1[1] <= idx2[1]
def has_overlapping(idx1, idx2):
    """Conflict test for flat NER: True when two spans share any position.

    End indices are inclusive, so ``(0, 2)`` and ``(2, 4)`` overlap. Spans
    with identical ``(start, end)`` always conflict, even if start > end.
    """
    if idx1[:2] == idx2[:2]:
        return True
    separated = idx1[0] > idx2[1] or idx2[0] > idx1[1]
    return not separated
def has_overlapping_nested(idx1, idx2):
    """Conflict test for nested NER decoding.

    Two spans conflict (True) when they have the same ``(start, end)`` or when
    they partially overlap, i.e. cross without either containing the other.
    Disjoint spans and strictly nested spans are compatible (False). End
    indices are inclusive.
    """
    if idx1[:2] == idx2[:2]:
        return True
    apart = idx1[0] > idx2[1] or idx2[0] > idx1[1]
    compatible = (apart or is_nested(idx1, idx2)) and idx1 != idx2
    return not compatible
def greedy_search(spans, flat_ner=True):  # start, end, class, score
    """Greedily keep the highest-scoring spans that do not conflict.

    Args:
        spans: Span records whose last element is a score; the trailing
            comment indicates ``(start, end, class, score)``.
        flat_ner: If True, any overlap between two spans is a conflict.
            Otherwise nested spans may coexist and only crossing or identical
            spans conflict.

    Returns:
        The kept spans with the score removed (``span[:-1]``), sorted by their
        first element (start).
    """
    conflicts = has_overlapping if flat_ner else has_overlapping_nested
    kept = []
    # Visit candidates by descending score; the stable sort preserves input
    # order among equal scores.
    for candidate in sorted(spans, key=lambda x: -x[-1]):
        span = candidate[:-1]
        if not any(conflicts(span, prev) for prev in kept):
            kept.append(span)
    return sorted(kept, key=lambda x: x[0])
|