|
from lm_eval.api.task import ConfigurableTask |
|
from lm_eval.api.instance import Instance |
|
|
|
|
|
from lm_eval.api.metrics import mean |
|
|
|
import torch |
|
import sacrebleu |
|
from rouge_score import rouge_scorer, scoring |
|
|
|
|
|
def bleu(refs, preds):
    """Compute a corpus-level, `t5`-style BLEU score.

    Mirrors the settings of the T5 reference implementation:
    https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L41

    :param refs:
        A `list` of `list` of reference `str`s, passed straight through as
        sacrebleu's `references` argument (presumably one inner list per
        reference stream, aligned with `preds` -- confirm against sacrebleu docs).
    :param preds:
        A `list` of predicted `str`s.
    :returns:
        The `.score` attribute of sacrebleu's `corpus_bleu` result.
    """
    # Fixed scoring options matching the T5 implementation linked above.
    t5_settings = dict(
        smooth_method="exp",
        smooth_value=0.0,
        force=False,
        lowercase=False,
        tokenize="intl",
        use_effective_order=False,
    )
    corpus_result = sacrebleu.corpus_bleu(preds, refs, **t5_settings)
    return corpus_result.score
|
|
|
|
|
def rouge(refs, preds):
    """Return `t5` style ROUGE scores. See the related implementation:
    https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L68

    :param refs:
        A `list` of reference `str`s.
    :param preds:
        A `list` of predicted `str`s, paired element-wise with `refs`
        (pairs are formed with `zip`, so extra items in the longer list are ignored).
    :returns:
        A `dict` with keys "rouge1", "rouge2" and "rougeLsum", each mapped to
        the aggregator's `mid` F-measure multiplied by 100.
    """
    rouge_types = ["rouge1", "rouge2", "rougeLsum"]
    scorer = rouge_scorer.RougeScorer(rouge_types)

    def _prepare_summary(summary):
        # Put " . "-delimited sentences on separate lines; rouge_score's
        # rougeLsum uses newlines as sentence boundaries (per its docs).
        return summary.replace(" . ", ".\n")

    aggregator = scoring.BootstrapAggregator()
    for ref, pred in zip(refs, preds):
        aggregator.add_scores(
            scorer.score(_prepare_summary(ref), _prepare_summary(pred))
        )
    result = aggregator.aggregate()

    # `rouge_type` (not `type`) avoids shadowing the builtin.
    return {
        rouge_type: result[rouge_type].mid.fmeasure * 100
        for rouge_type in rouge_types
    }
|
|
|
|
|
|
|
class CNNDMv2(ConfigurableTask):
    """CNN/DailyMail (config "3.0.0") summarization task.

    The model is prompted with an article and asked for a summary. Each
    generated summary is scored against the reference ``highlights`` with
    ROUGE-1/2/Lsum, a FactKB classifier over the (summary, article) pair, and
    BERTScore. The FactKB and BERTScore scorers are loaded lazily on first use.
    """

    VERSION = 2
    DATASET_PATH = "cnn_dailymail"
    DATASET_NAME = "3.0.0"

    # Every submetric emitted by ``process_results``. ``aggregation`` and
    # ``higher_is_better`` both derive from this so the lists cannot drift apart.
    _METRIC_NAMES = (
        "rouge1",
        "rouge2",
        "rougeL",
        "factKB",
        "bertscore_precision",
        "bertscore_recall",
        "bertscore_f1",
    )

    def __init__(self):
        super().__init__(
            config={
                "metadata": {"version": self.VERSION},
                "generation_kwargs": {"do_sample": False, "temperature": 0.0, "until": ["\n", "\n\n"]},
            }
        )
        # Scoring models are created on demand by the ``maybe_init_*`` helpers.
        self.factkb_tokenizer = None
        self.factkb_model = None
        self.bert_score = None

    def maybe_init_factkb(self):
        """Load the FactKB tokenizer and classifier if not already loaded."""
        if self.factkb_tokenizer is None or self.factkb_model is None:
            from transformers import AutoTokenizer, AutoModelForSequenceClassification

            # padding/truncation are also passed explicitly at tokenization time.
            self.factkb_tokenizer = AutoTokenizer.from_pretrained(
                "roberta-base", padding="max_length", truncation=True
            )
            self.factkb_model = AutoModelForSequenceClassification.from_pretrained(
                "bunsenfeng/FactKB", num_labels=2, device_map="auto"
            )

    def maybe_init_bertscore(self):
        """Load the ``evaluate`` BERTScore metric if not already loaded."""
        if self.bert_score is None:
            from evaluate import load

            self.bert_score = load("bertscore")

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        return self.dataset["test"]

    def doc_to_text(self, doc):
        return f'Article: {doc["article"]}\nSummarize the article. Summary:'

    @staticmethod
    def should_decontaminate():
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["article"]

    def doc_to_target(self, doc):
        return doc["highlights"]

    def construct_requests(self, doc, ctx, **kwargs):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :returns:
            A single-element list with one ``generate_until`` Instance.
        """
        from copy import deepcopy

        # Newer lm_eval versions may forward chat-template options here. They
        # only affect prompt building (already reflected in ``ctx``) and are
        # not ``Instance`` fields, so drop them before forwarding ``kwargs``.
        kwargs.pop("apply_chat_template", None)
        kwargs.pop("chat_template", None)

        # Use the generation settings declared in __init__ (greedy decoding,
        # newline stop) rather than a hard-coded dict, so backends that might
        # otherwise sample still decode greedily. Fall back to the previous
        # hard-coded stop list if the config does not expose them.
        config = getattr(self, "config", None)
        gen_kwargs = getattr(config, "generation_kwargs", None) or {"until": ["\n"]}

        # Deep-copied so a backend that mutates its kwargs cannot alter the
        # shared config for later requests.
        return [
            Instance(
                request_type="generate_until",
                doc=doc,
                arguments=(ctx, deepcopy(gen_kwargs)),
                idx=0,
                **kwargs,
            )
        ]

    def _factkb_score(self, summary, document):
        """Return the FactKB softmax probability at label index 1 for (summary, document)."""
        self.maybe_init_factkb()
        factkb_tokens = self.factkb_tokenizer(
            [[summary, document]], return_tensors="pt", padding="max_length", truncation=True
        ).to(self.factkb_model.device)
        # Inference only: avoid building an autograd graph.
        with torch.no_grad():
            factkb_logits = self.factkb_model(**factkb_tokens).logits
        factkb_probs = torch.softmax(factkb_logits, dim=1)
        # NOTE(review): index 1 is presumably the "factually consistent"
        # label; confirm against the bunsenfeng/FactKB model card.
        return float(factkb_probs[0][1])

    def process_results(self, doc, results):
        """Score one generated summary.

        :param doc:
            Dataset row; ``doc["article"]`` is the source text and
            ``doc["highlights"]`` the reference summary.
        :param results:
            Model outputs for ``doc``; ``results[0]`` is the generated summary.
        :returns:
            A dict mapping every name in ``_METRIC_NAMES`` to a float.
        """
        completion = results[0]
        document = doc["article"]
        gold_summary = doc["highlights"]

        rouge_scores = rouge([gold_summary], [completion])

        factkb_score = self._factkb_score(completion, document)

        self.maybe_init_bertscore()
        bert_score_res = self.bert_score.compute(
            predictions=[completion], references=[gold_summary], model_type="microsoft/deberta-xlarge-mnli", lang="en"
        )

        return {
            "rouge1": rouge_scores["rouge1"],
            "rouge2": rouge_scores["rouge2"],
            "rougeL": rouge_scores["rougeLsum"],
            "factKB": factkb_score,
            "bertscore_precision": float(bert_score_res["precision"][0]),
            "bertscore_recall": float(bert_score_res["recall"][0]),
            "bertscore_f1": float(bert_score_res["f1"][0]),
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {name: mean for name in self._METRIC_NAMES}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {name: True for name in self._METRIC_NAMES}
|
|