from collections import Counter

from rouge_score import rouge_scorer

# ROUGE variants the shared scorer below is configured to compute.
ROUGE_TYPES = ["rouge1", "rouge2", "rougeL"]

# NOTE: this deliberately rebinds the name ``rouge_scorer`` from the imported
# module to a ready-to-use scorer instance (stemming enabled); the name is
# kept so existing importers of this module-level object keep working.
rouge_scorer = rouge_scorer.RougeScorer(
    ROUGE_TYPES, use_stemmer=True
)


def compute_token_f1(tgt_tokens, pred_tokens, use_counts: bool = True) -> float:
    """Return the token-level F1 score between target and predicted tokens.

    Precision is the overlap divided by the number of predicted tokens and
    recall is the overlap divided by the number of target tokens; the result
    is their harmonic mean.

    Args:
        tgt_tokens: Iterable of hashable reference tokens.
        pred_tokens: Iterable of hashable predicted tokens.
        use_counts: If True (default), a token matches as many times as it
            occurs in *both* inputs (multiset overlap). If False, duplicates
            are dropped first, so every distinct token counts at most once.

    Returns:
        The F1 score as a float in [0, 1]. Returns 0.0 when the inputs share
        no token, which includes the case where either input is empty.
    """
    # Materialize each input exactly once. The tokens are needed twice (for
    # counting and for the length), so a generator or other one-shot iterable
    # would otherwise be exhausted after the first pass and score 0.
    container = list if use_counts else set
    tgt_tokens = container(tgt_tokens)
    pred_tokens = container(pred_tokens)

    # Counter intersection keeps min(count_in_tgt, count_in_pred) per token.
    shared = Counter(tgt_tokens) & Counter(pred_tokens)
    overlap = sum(shared.values())
    if overlap == 0:
        # No common token; also guards the divisions below against empty input.
        return 0.

    # overlap > 0 implies both inputs are non-empty and p, r are positive.
    p = overlap / len(pred_tokens)
    r = overlap / len(tgt_tokens)
    return (2 * p * r) / (p + r)