Spaces: Running on CPU Upgrade
from lm_eval.api.task import Task | |
from lm_eval.api.instance import Instance | |
from lm_eval.api.registry import register_task | |
from lm_eval.api.metrics import mean | |
import datasets | |
from src.backend.tasks.cnndm import utils | |
class CnnDm(Task):
    """lm-eval-harness summarization task over the CNN/DailyMail dataset.

    Each document's ``article`` field is used as the prompt and its
    ``highlights`` field as the reference summary. Per-document scoring is
    delegated to ``utils.process_results``. Sub-metrics are averaged with
    ``mean`` during aggregation.
    """

    VERSION = 0
    DATASET_PATH = "cnn_dailymail"
    DATASET_NAME = "3.0.0"

    # Single source of truth for the sub-metric names, shared by
    # aggregation() and higher_is_better() so the two cannot drift apart.
    # NOTE(review): these must match the keys produced by
    # utils.process_results; confirm against that module.
    _ROUGE_METRICS = ("rouge1", "rouge2", "rougeL")

    def __init__(self, data_dir=None, cache_dir=None, download_mode=None, config=None):
        """Forward all dataset/config options unchanged to ``Task.__init__``."""
        super().__init__(
            data_dir=data_dir,
            cache_dir=cache_dir,
            download_mode=download_mode,
            config=config,
        )

    def has_training_docs(self):
        """Report that a ``train`` split is available."""
        return True

    def has_validation_docs(self):
        """Report that a ``validation`` split is available."""
        return True

    def has_test_docs(self):
        """Report that a ``test`` split is available."""
        return True

    def training_docs(self):
        """Return the ``train`` split of the loaded dataset."""
        return self.dataset["train"]

    def validation_docs(self):
        """Return the ``validation`` split of the loaded dataset."""
        return self.dataset["validation"]

    def test_docs(self):
        """Return the ``test`` split of the loaded dataset."""
        return self.dataset["test"]

    def doc_to_text(self, doc):
        """Build the prompt: the article text followed by a ``Summary:`` cue."""
        return f'Document: {doc["article"]}\nSummary:'

    def should_decontaminate(self):
        """Enable decontamination checks for this task.

        The original declaration omitted ``self``. Any call through an
        instance therefore raised ``TypeError``.
        """
        return True

    def doc_to_decontamination_query(self, doc):
        """Return the text checked for contamination: the source article."""
        return doc["article"]

    def doc_to_target(self, doc):
        """Return the reference summary (the ``highlights`` field)."""
        return doc["highlights"]

    def construct_requests(self, doc, ctx, **kwargs):
        """Build the single generation request sent to the LM for ``doc``.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        :param kwargs:
            Extra keyword arguments forwarded verbatim to ``Instance``.
        :returns: list[Instance]
            A one-element list holding a ``generate_until`` request.
        """
        # Generation halts at the first newline or period.
        # NOTE(review): stopping at "." truncates the output to one sentence,
        # while reference highlights are often several sentences. Confirm this
        # is intended before changing it, because it affects reported scores.
        stop_sequences = ["\n", "."]
        request = Instance(
            request_type="generate_until",
            doc=doc,
            arguments=(ctx, {"until": stop_sequences}),
            idx=0,
            **kwargs,
        )
        return [request]

    def process_results(self, doc, results):
        """Score the model output for ``doc`` via ``utils.process_results``."""
        return utils.process_results(doc, results)

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics (here, always ``mean``)
        """
        return {name: mean for name in self._ROUGE_METRICS}

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better (True for all)
        """
        return {name: True for name in self._ROUGE_METRICS}