# MIT License
# Copyright (c) 2022 Alireza Mohammadshahi
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
""" RQUGE metric. """ | |
import functools | |
from contextlib import contextmanager | |
from rquge_score.scorer import RQUGE | |
import datasets | |
from packaging import version | |
import evaluate | |
@contextmanager
def filter_logging_context():
    """Temporarily silence the transformers "This IS expected if you are
    initializing" warning while model weights are being loaded.

    The filter is installed on the ``transformers.modeling_utils`` logger on
    entry and removed on exit, even if the managed body raises.
    """

    def filter_log(record):
        # Drop only the specific initialization warning; keep all other records.
        return "This IS expected if you are initializing" not in record.msg

    logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    logger.addFilter(filter_log)
    try:
        yield
    finally:
        # Always restore the logger so the filter does not leak past this scope.
        logger.removeFilter(filter_log)
_CITATION = """\ | |
@misc{mohammadshahi2022rquge, | |
title={RQUGE: Reference-Free Metric for Evaluating Question Generation by Answering the Question}, | |
author={Alireza Mohammadshahi and Thomas Scialom and Majid Yazdani and Pouya Yanki and Angela Fan and James Henderson and Marzieh Saeidi}, | |
year={2022}, | |
eprint={2211.01482}, | |
archivePrefix={arXiv}, | |
primaryClass={cs.CL} | |
} | |
""" | |
_DESCRIPTION = """\ | |
RQUGE, a Reference-free QUestion Generation Evaluation metric that can compute the quality of | |
the candidate question without requiring the access to the reference question. | |
Given the corresponding context and answer span, our metric calculates the acceptability score | |
by applying a general question-answering module, followed by a span scorer. You can find | |
more detail in the paper (https://arxiv.org/abs/2211.01482) (ACL2023). | |
""" | |
_KWARGS_DESCRIPTION = """ | |
RQUGE Metric to compute the acceptability of generated question, given the context and answer. | |
Args: | |
generated_questions (list of str): Generated/candidate questions. | |
contexts (list of str): List of contexts. | |
answers (list of str): List of reference answers. | |
qa_model (str): Path to the QA model (local path or HF model hub), default: 'allenai/unifiedqa-v2-t5-large-1363200' | |
sp_model (str): Path of span scorer model (local path or HF model hub), default: 'alirezamsh/quip-512-mocha' | |
verbose (bool): Turn on intermediate status update. | |
device (str): On which the contextual embedding model will be allocated on. | |
If this argument is None, the model lives on cuda:0 if cuda is available. | |
nthreads (int): Number of threads. | |
batch_size (int): Bert score processing batch size, | |
at least one of `model_type` or `lang`. `lang` needs to be | |
specified when `rescale_with_baseline` is True. | |
Returns: | |
score: RQUGE score. | |
Examples: | |
>>> generated_questions = ["how is the weather?"] | |
>>> contexts = ["the weather is sunny"] | |
>>> answers = ["sunny"] | |
>>> rqugescore = evaluate.load("rquge") | |
>>> results = rquge.compute(generated_questions=generated_questions, contexts=contexts, answers=answers) | |
>>> print([round(v, 2) for v in results["score"]]) | |
[5.0] | |
""" | |
class RQUGEScore(evaluate.Metric):
    """Reference-free evaluation of generated questions given context and answer.

    Wraps the RQUGE scorer (QA module + span scorer) as a HuggingFace
    ``evaluate`` metric.
    """

    def _info(self):
        """Return metric metadata: description, citation, and expected input features."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/alirezamshi/RQUGE",
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "generated_questions": datasets.Value("string", id="sequence"),
                        "contexts": datasets.Value("string", id="sequence"),
                        "answers": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            codebase_urls=["https://github.com/alirezamshi/RQUGE"],
            reference_urls=[
                "https://github.com/alirezamshi/RQUGE",
                "https://arxiv.org/abs/2211.01482",
            ],
        )

    def _compute(
        self,
        generated_questions,
        contexts,
        answers,
        qa_model="allenai/unifiedqa-v2-t5-large-1363200",
        sp_model="alirezamsh/quip-512-mocha",
        verbose=False,
        device='cpu',
    ):
        """Score each (context, question, answer) triple with RQUGE.

        Args:
            generated_questions: Candidate questions to evaluate.
            contexts: Context passages, aligned with ``generated_questions``.
            answers: Reference answer spans, aligned with ``generated_questions``.
            qa_model: QA model path (local or HF hub).
            sp_model: Span-scorer model path (local or HF hub).
            verbose: If True, print the mean score.
            device: Device the models are loaded on.

        Returns:
            dict with "mean_score" (float, 0.0 for empty input) and
            "instance_score" (list of per-instance scores).
        """
        rquge_model = RQUGE(sp_scorer_path=sp_model, qa_model_path=qa_model, device=device)

        scores = [
            rquge_model.scorer(context, question, answer)
            for context, question, answer in zip(contexts, generated_questions, answers)
        ]
        # Guard against empty input — the original divided by len(output) unconditionally,
        # which raises ZeroDivisionError when no triples are supplied.
        mean_score = sum(scores) / len(scores) if scores else 0.0

        if verbose:
            print(f'Average RQUGE score is {mean_score}')

        return {
            "mean_score": mean_score,
            "instance_score": scores,
        }