unite_mup / cometda.py
alvations's picture
added comet da
9dd2404
raw
history blame contribute delete
No virus
2.65 kB
import os
import pathlib
import datasets
import evaluate
from huggingface_hub import snapshot_download, login
from comet.models.multitask.unified_metric import UnifiedMetric
# BibTeX entry exposed through ``MetricInfo.citation`` in COMETDA._info.
# Accent commands are written as ``{\\'e}`` so the stored text keeps the
# backslash: in a normal (non-raw) string a lone ``\'`` is just an escaped
# quote and the backslash would be dropped.
# NOTE(review): this cites COMET-22, while COMETDA downloads
# Unbabel/wmt22-cometkiwi-da — confirm which paper should be cited.
_CITATION = """\
@inproceedings{rei-etal-2022-comet,
title = "{COMET}-22: Unbabel-{IST} 2022 Submission for the Metrics Shared Task",
author = "Rei, Ricardo and
C. de Souza, Jos{\\'e} G. and
Alves, Duarte and
Zerva, Chrysoula and
Farinha, Ana C and
Glushkova, Taisiya and
Lavie, Alon and
Coheur, Luisa and
Martins, Andr{\\'e} F. T.",
booktitle = "Proceedings of the Seventh Conference on Machine Translation (WMT)",
month = dec,
year = "2022",
address = "Abu Dhabi, United Arab Emirates (Hybrid)",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2022.wmt-1.52",
pages = "578--585",
}
"""
# Short human-readable summary exposed through ``MetricInfo.description``.
# NOTE(review): this links Unbabel/unite-mup, but COMETDA downloads
# Unbabel/wmt22-cometkiwi-da — confirm which model page is intended.
_DESCRIPTION = "From https://huggingface.co/Unbabel/unite-mup\n"
class COMETDA(evaluate.Metric):
    """Hugging Face ``evaluate`` metric backed by Unbabel/wmt22-cometkiwi-da.

    The checkpoint is loaded as a COMET ``UnifiedMetric`` in
    ``_download_and_prepare``. ``compute`` (through ``_compute``) scores
    prediction/reference pairs whose model input fields are chosen with
    ``data_keys``; ``compute_triplet`` scores source/hypothesis/reference
    triplets and returns the model's unified scores.
    """

    def _info(self):
        """Return the metric metadata: description, citation and input schema.

        Inputs are two string columns, ``predictions`` and ``references``.
        """
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
        )

    def _download_and_prepare(self, dl_manager):
        """Load the COMET checkpoint into ``self.model``, downloading if needed.

        The first ``*.ckpt`` found anywhere under
        ``./models--Unbabel--wmt22-cometkiwi-da/`` (relative to the current
        working directory) is reused. If none exists, the repository snapshot
        is downloaded with ``cache_dir`` set to the current working directory
        and ``checkpoints/model.ckpt`` from it is loaded.

        Args:
            dl_manager: Unused; part of the ``evaluate.Metric`` hook signature.
        """
        model_checkpoint_path = next(
            pathlib.Path("./models--Unbabel--wmt22-cometkiwi-da/").rglob("*.ckpt"),
            None,
        )
        if model_checkpoint_path is None:
            # snapshot_download lays the repo out under
            # <cache_dir>/models--Unbabel--wmt22-cometkiwi-da/, which is the
            # directory searched above, so later runs reuse the download.
            model_path = snapshot_download(
                repo_id="Unbabel/wmt22-cometkiwi-da",
                cache_dir=os.getcwd(),
            )
            model_checkpoint_path = pathlib.Path(model_path) / "checkpoints" / "model.ckpt"
        # Only a missing local checkpoint triggers a download; errors raised
        # while loading propagate instead of being swallowed and retried.
        self.model = UnifiedMetric.load_from_checkpoint(model_checkpoint_path)

    def _compute(
        self,
        predictions,
        references,
        data_keys=None,
        batch_size=8,
    ):
        """Score each (prediction, reference) pair with the loaded model.

        Lets the caller treat either source inputs or reference translations
        as the ground truth: ``data_keys`` names the model input field that
        each column is fed into.

        Args:
            predictions: Hypothesis strings.
            references: Ground-truth strings, paired with ``predictions`` by
                position.
            data_keys: Field names ``(prediction_key, reference_key)``, e.g.
                ``("mt", "src")`` or ``("mt", "ref")``. Required.
            batch_size: Batch size forwarded to ``self.model.predict``.

        Returns:
            ``{"scores": ...}`` holding the ``scores`` of the model prediction.

        Raises:
            TypeError: If ``data_keys`` is not given.
        """
        if data_keys is None:
            raise TypeError(
                "COMETDA requires `data_keys`, e.g. data_keys=('mt', 'src') or "
                "('mt', 'ref'), naming the model fields for predictions and references."
            )
        prediction_key, reference_key = data_keys[0], data_keys[1]
        data = [
            {prediction_key: prediction, reference_key: reference}
            for prediction, reference in zip(predictions, references)
        ]
        return {"scores": self.model.predict(data, batch_size=batch_size).scores}

    def compute_triplet(
        self,
        predictions,
        references,
        sources,
        batch_size=8,
    ):
        """Score (source, hypothesis, reference) triplets with unified scores.

        Args:
            predictions: Hypothesis strings (fed as ``"mt"``).
            references: Reference translations (fed as ``"ref"``).
            sources: Source sentences (fed as ``"src"``).
            batch_size: Batch size forwarded to ``self.model.predict``.

        Returns:
            ``{"scores": ...}`` taken from ``metadata.unified_scores`` of the
            model prediction.
        """
        # NOTE(review): confirm the loaded cometkiwi-da checkpoint populates
        # ``metadata.unified_scores`` when given a reference.
        data = [
            {"src": source, "mt": prediction, "ref": reference}
            for source, prediction, reference in zip(sources, predictions, references)
        ]
        prediction_output = self.model.predict(data, batch_size=batch_size)
        return {"scores": prediction_output.metadata.unified_scores}