from collections import defaultdict

import numpy as np
import torch

from seqeval.metrics.v1 import _prf_divide


def extract_tp_actual_correct(y_true, y_pred):
    # Group entities by type; each entry is [label, (start, end), sample_idx].
    entities_true = defaultdict(set)
    entities_pred = defaultdict(set)

    for type_name, (start, end), idx in y_true:
        entities_true[type_name].add((start, end, idx))
    for type_name, (start, end), idx in y_pred:
        entities_pred[type_name].add((start, end, idx))

    target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))

    # Per-type counts: true positives, predicted entities, and gold entities.
    # Accumulate in lists (np.append in a loop copies the array each time).
    tp_sum, pred_sum, true_sum = [], [], []
    for type_name in target_names:
        entities_true_type = entities_true.get(type_name, set())
        entities_pred_type = entities_pred.get(type_name, set())
        tp_sum.append(len(entities_true_type & entities_pred_type))
        pred_sum.append(len(entities_pred_type))
        true_sum.append(len(entities_true_type))

    tp_sum = np.array(tp_sum, dtype=np.int32)
    pred_sum = np.array(pred_sum, dtype=np.int32)
    true_sum = np.array(true_sum, dtype=np.int32)

    return pred_sum, tp_sum, true_sum, target_names
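
# Illustrative sketch (not part of the original module): inputs are flat lists
# of [label, (start, end), sample_idx] entries; outputs are per-type counts
# aligned with the sorted type names. The data below is made up.
#
# true = [['PER', (0, 1), 0], ['ORG', (2, 4), 1]]
# pred = [['PER', (0, 1), 0], ['LOC', (2, 4), 1]]
# pred_sum, tp_sum, true_sum, names = extract_tp_actual_correct(true, pred)
# pred_sum -> [1, 0, 1]; tp_sum -> [0, 0, 1]; true_sum -> [0, 1, 1]
# names -> ['LOC', 'ORG', 'PER']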


def flatten_for_eval(y_true, y_pred):
    # Flatten per-sample entity lists into one list, tagging each entity with
    # its sample index so identical spans from different samples stay distinct
    # when sets are built downstream.
    all_true = []
    all_pred = []

    for i, (true, pred) in enumerate(zip(y_true, y_pred)):
        all_true.extend([t + [i] for t in true])
        all_pred.extend([p + [i] for p in pred])

    return all_true, all_pred
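
# Hypothetical example (labels and spans made up):
#
# flatten_for_eval([[['PER', (0, 1)]]], [[['PER', (0, 1)]]])
# -> ([['PER', (0, 1), 0]], [['PER', (0, 1), 0]])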


def compute_prf(y_true, y_pred, average='micro'):
    y_true, y_pred = flatten_for_eval(y_true, y_pred)

    pred_sum, tp_sum, true_sum, target_names = extract_tp_actual_correct(y_true, y_pred)

    if average == 'micro':
        # Collapse the per-type counts into global totals.
        tp_sum = np.array([tp_sum.sum()])
        pred_sum = np.array([pred_sum.sum()])
        true_sum = np.array([true_sum.sum()])

    # precision = TP / predicted, with seqeval's zero-division handling.
    precision = _prf_divide(
        numerator=tp_sum,
        denominator=pred_sum,
        metric='precision',
        modifier='predicted',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )

    # recall = TP / gold, with the same zero-division handling.
    recall = _prf_divide(
        numerator=tp_sum,
        denominator=true_sum,
        metric='recall',
        modifier='true',
        average=average,
        warn_for=('precision', 'recall', 'f-score'),
        zero_division='warn'
    )

    # Harmonic mean; guard against 0/0 when precision and recall are both 0.
    denominator = precision + recall
    denominator[denominator == 0.] = 1
    f_score = 2 * (precision * recall) / denominator

    # Indexing [0] assumes average='micro', where each array has one entry.
    return {'precision': precision[0], 'recall': recall[0], 'f_score': f_score[0]}
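
# Illustrative sketch with made-up spans: one of two predictions is correct
# and one of two gold entities is found, so micro P = R = F1 = 0.5.
#
# compute_prf([[['PER', (0, 1)]], [['ORG', (2, 4)]]],
#             [[['PER', (0, 1)]], [['LOC', (2, 4)]]])
# -> {'precision': 0.5, 'recall': 0.5, 'f_score': 0.5}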


class Evaluator:
    def __init__(self, all_true, all_outs):
        self.all_true = all_true
        self.all_outs = all_outs

    def get_entities_fr(self, ents):
        # Convert (start, end, label) triples into the [label, (start, end)]
        # format expected by compute_prf.
        all_ents = []
        for s, e, lab in ents:
            all_ents.append([lab, (s, e)])
        return all_ents

    def transform_data(self):
        all_true_ent = []
        all_outs_ent = []
        for i, j in zip(self.all_true, self.all_outs):
            all_true_ent.append(self.get_entities_fr(i))
            all_outs_ent.append(self.get_entities_fr(j))
        return all_true_ent, all_outs_ent

    @torch.no_grad()
    def evaluate(self):
        all_true_typed, all_outs_typed = self.transform_data()
        metrics = compute_prf(all_true_typed, all_outs_typed)
        # Unpack by key rather than relying on dict insertion order.
        precision, recall, f1 = metrics['precision'], metrics['recall'], metrics['f_score']
        output_str = f"P: {precision:.2%}\tR: {recall:.2%}\tF1: {f1:.2%}\n"
        return output_str, f1
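
# Illustrative usage (hypothetical data): gold and predicted entities are
# per-sample lists of (start, end, label) triples.
#
# evaluator = Evaluator(all_true=[[(0, 1, 'PER')]], all_outs=[[(0, 1, 'PER')]])
# report, f1 = evaluator.evaluate()  # f1 == 1.0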


def is_nested(idx1, idx2):
    # True if one span's [start, end] range fully contains the other's.
    return (idx1[0] <= idx2[0] and idx1[1] >= idx2[1]) or (idx2[0] <= idx1[0] and idx2[1] >= idx1[1])
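
# e.g. is_nested((0, 5), (1, 2)) -> True   (second span inside the first)
#      is_nested((0, 3), (2, 5)) -> False  (crossing, not containment)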


def has_overlapping(idx1, idx2):
    # Flat NER: two spans conflict if their [start, end] ranges intersect at
    # all, including identical boundaries.
    if idx1[:2] == idx2[:2]:
        return True
    return not (idx1[0] > idx2[1] or idx2[0] > idx1[1])


def has_overlapping_nested(idx1, idx2):
    # Nested NER: spans with identical boundaries still conflict, but a span
    # that is disjoint from, or fully contained in, another span is allowed.
    # (At this point idx1[:2] != idx2[:2], so the spans are never identical.)
    if idx1[:2] == idx2[:2]:
        return True
    if (idx1[0] > idx2[1] or idx2[0] > idx1[1]) or is_nested(idx1, idx2):
        return False
    return True
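
# Made-up examples contrasting the two rules:
# has_overlapping((0, 3, 'PER'), (2, 5, 'ORG'))        -> True   (any intersection conflicts)
# has_overlapping_nested((2, 5, 'ORG'), (0, 3, 'PER')) -> True   (crossing still conflicts)
# has_overlapping_nested((1, 2, 'PER'), (0, 5, 'ORG')) -> False  (containment is allowed)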


def greedy_search(spans, flat_ner=True):
    # Decode a final span set from scored candidates: visit spans in order of
    # decreasing score and keep each one that does not conflict with an
    # already-kept span under the chosen overlap rule.
    has_ov = has_overlapping if flat_ner else has_overlapping_nested

    new_list = []
    # Each span is (start, end, label, score); sort by descending score.
    span_prob = sorted(spans, key=lambda x: -x[-1])
    for b in span_prob:
        if not any(has_ov(b[:-1], kept) for kept in new_list):
            new_list.append(b[:-1])  # drop the score, keep (start, end, label)
    # Return the kept spans in document order.
    return sorted(new_list, key=lambda x: x[0])
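
# Illustrative decode (hypothetical scored spans):
#
# spans = [(0, 3, 'PER', 0.9), (2, 5, 'ORG', 0.8), (6, 7, 'LOC', 0.7)]
# greedy_search(spans, flat_ner=True)   -> [(0, 3, 'PER'), (6, 7, 'LOC')]
#
# nested = [(0, 5, 'ORG', 0.9), (1, 2, 'PER', 0.8)]
# greedy_search(nested, flat_ner=False) -> [(0, 5, 'ORG'), (1, 2, 'PER')]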