saridormi's picture
Make files submissions instructions task-specific & other small changes
6c92442
raw
history blame
1.67 kB
from typing import Dict, List
import evaluate # type: ignore[import]
from ..base_task_metrics import BaseTaskMetrics
class CMGMetrics(BaseTaskMetrics):
    """Metrics aggregator for the Commit Message Generation (CMG) task.

    Wraps five HuggingFace `evaluate` metrics — sacreBLEU, chrF, ROUGE and two
    BERTScore instances (raw and baseline-rescaled) — and exposes the usual
    accumulate-then-compute workflow via `add_batch` / `compute`.
    """

    def __init__(self):
        # Two separate BERTScore instances are loaded on purpose: each
        # `evaluate` metric object keeps its own accumulated batch state, and
        # `compute` is called with different kwargs (raw vs. rescaled).
        self.bleu = evaluate.load("sacrebleu")
        self.chrf = evaluate.load("chrf")
        self.rouge = evaluate.load("rouge")
        self.bertscore = evaluate.load("bertscore")
        self.bertscore_normalized = evaluate.load("bertscore")

    def add_batch(self, predictions: List[str], references: List[str], *args, **kwargs) -> None:
        """Accumulate one batch of predictions and references into every metric.

        Args:
            predictions: Generated commit messages, one per example.
            references: Ground-truth commit messages, one per example.

        Extra positional/keyword arguments are accepted for interface
        compatibility and ignored.
        """
        # sacreBLEU and chrF expect a list of reference *lists* per prediction;
        # build the wrapped form once instead of once per metric.
        wrapped_references = [[ref] for ref in references]
        self.bleu.add_batch(predictions=predictions, references=wrapped_references)
        self.chrf.add_batch(predictions=predictions, references=wrapped_references)
        self.rouge.add_batch(predictions=predictions, references=references)
        self.bertscore.add_batch(predictions=predictions, references=references)
        self.bertscore_normalized.add_batch(predictions=predictions, references=references)

    def compute(self, *args, **kwargs) -> Dict[str, float]:
        """Compute all accumulated metrics and return them as a flat dict.

        Returns:
            Mapping with keys "bleu", "chrf", "rouge1", "rouge2", "rougeL",
            "bertscore", "bertscore_normalized". ROUGE values are scaled to
            0-100 to match the BLEU/chrF scale; BERTScore values are the mean
            F1 over all accumulated examples.

        NOTE(review): calling this before any `add_batch` presumably fails
        (empty F1 list would divide by zero) — confirm against callers.
        """
        rouge = self.rouge.compute()
        bertscore = self.bertscore.compute(lang="en")
        bertscore_normalized = self.bertscore_normalized.compute(lang="en", rescale_with_baseline=True)
        return {
            "bleu": self.bleu.compute(tokenize="13a")["score"],
            "chrf": self.chrf.compute()["score"],
            "rouge1": rouge["rouge1"] * 100,
            "rouge2": rouge["rouge2"] * 100,
            "rougeL": rouge["rougeL"] * 100,
            "bertscore": sum(bertscore["f1"]) / len(bertscore["f1"]),
            "bertscore_normalized": sum(bertscore_normalized["f1"]) / len(bertscore_normalized["f1"]),
        }