"""`evaluate` metric wrapper around a COMET ``UnifiedMetric`` checkpoint."""

import logging
import os
import pathlib

import datasets
import evaluate
from huggingface_hub import snapshot_download, login

from comet.models.multitask.unified_metric import UnifiedMetric

logger = logging.getLogger(__name__)

# Hub repository holding the checkpoint, and the directory (relative to the
# current working directory) that `snapshot_download` creates for it when the
# working directory is used as `cache_dir`.
_MODEL_REPO_ID = "Unbabel/wmt22-cometkiwi-da"
_LOCAL_SNAPSHOT_DIR = "./models--Unbabel--wmt22-cometkiwi-da/"

# The accents are written as ``\\'`` on purpose: in a non-raw string ``\'``
# is an escape for a plain quote, which would drop the backslash that BibTeX
# needs for ``{\'e}``.
_CITATION = """\
@inproceedings{rei-etal-2022-comet,
    title = "{COMET}-22: Unbabel-{IST} 2022 Submission for the Metrics Shared Task",
    author = "Rei, Ricardo  and
      C. de Souza, Jos{\\'e} G.  and
      Alves, Duarte  and
      Zerva, Chrysoula  and
      Farinha, Ana C  and
      Glushkova, Taisiya  and
      Lavie, Alon  and
      Coheur, Luisa  and
      Martins, Andr{\\'e} F. T.",
    booktitle = "Proceedings of the Seventh Conference on Machine Translation (WMT)",
    month = dec,
    year = "2022",
    address = "Abu Dhabi, United Arab Emirates (Hybrid)",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2022.wmt-1.52",
    pages = "578--585",
}
"""

# NOTE(review): the description points at `Unbabel/unite-mup`, while the
# checkpoint loaded below is `Unbabel/wmt22-cometkiwi-da` -- confirm which
# one is intended.
_DESCRIPTION = """\
From https://huggingface.co/Unbabel/unite-mup
"""


def _find_local_checkpoint():
    """Return the first ``*.ckpt`` under the local snapshot dir, or ``None``."""
    try:
        return next(pathlib.Path(_LOCAL_SNAPSHOT_DIR).rglob("*.ckpt"), None)
    except OSError:
        # An unreadable directory is treated the same as a missing one.
        return None


class COMETDA(evaluate.Metric):
    """Scores translations with a ``UnifiedMetric`` checkpoint from the Hub."""

    def _info(self):
        """Describe the metric and the string features it consumes."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
        )

    def _download_and_prepare(self, dl_manager):
        """Load the checkpoint into ``self.model``.

        A checkpoint already present under the local snapshot directory is
        preferred. If none is found, or loading it fails, the repository is
        downloaded into the current working directory and loaded from there.

        Args:
            dl_manager: Unused; the checkpoint is fetched with
                ``snapshot_download`` instead.
        """
        local_checkpoint = _find_local_checkpoint()
        if local_checkpoint is not None:
            try:
                self.model = UnifiedMetric.load_from_checkpoint(local_checkpoint)
                return
            except Exception:
                # Best effort: an unusable local copy must not prevent a
                # fresh download, but the reason should not be lost.
                logger.warning(
                    "Could not load local checkpoint %s; downloading %s instead.",
                    local_checkpoint,
                    _MODEL_REPO_ID,
                    exc_info=True,
                )

        model_path = snapshot_download(
            repo_id=_MODEL_REPO_ID,
            cache_dir=os.getcwd(),
        )
        downloaded_checkpoint = os.path.join(model_path, "checkpoints", "model.ckpt")
        self.model = UnifiedMetric.load_from_checkpoint(downloaded_checkpoint)

    def _compute(
        self,
        predictions,
        references,
        data_keys=None,
        batch_size=8,
    ):
        """Score each prediction against its paired ground-truth text.

        Allows the user to use either source inputs or reference translations
        as ground truth, by choosing the dict keys the model sees.

        Args:
            predictions: Hypothesis strings.
            references: Ground-truth strings paired with ``predictions``.
            data_keys: Two keys; ``data_keys[0]`` labels each prediction and
                ``data_keys[1]`` labels each ground-truth text in the samples
                handed to the model.
            batch_size: Batch size forwarded to ``self.model.predict``.

        Returns:
            ``{"scores": ...}`` holding the ``scores`` attribute of the
            model's prediction output.

        Raises:
            ValueError: If ``data_keys`` is missing or has fewer than two keys.
        """
        if data_keys is None or len(data_keys) < 2:
            raise ValueError(
                "data_keys must provide two keys: the first for predictions and "
                "the second for references, e.g. data_keys=('mt', 'src')."
            )
        prediction_key, reference_key = data_keys[0], data_keys[1]
        data = [
            {prediction_key: p, reference_key: r}
            for p, r in zip(predictions, references)
        ]
        return {"scores": self.model.predict(data, batch_size=batch_size).scores}

    def compute_triplet(
        self,
        predictions,
        references,
        sources,
        batch_size=8,
    ):
        """Score using sources, hypotheses and references together.

        Args:
            predictions: Hypothesis strings (sent to the model as ``"mt"``).
            references: Reference translations (sent as ``"ref"``).
            sources: Source sentences (sent as ``"src"``).
            batch_size: Batch size forwarded to ``self.model.predict``.

        Returns:
            ``{"scores": ...}`` holding ``metadata.unified_scores`` from the
            model's prediction output.

        Raises:
            ValueError: If the three inputs differ in length.
        """
        sources = list(sources)
        predictions = list(predictions)
        references = list(references)
        # zip() would silently drop trailing items on a length mismatch.
        if not (len(sources) == len(predictions) == len(references)):
            raise ValueError(
                "sources, predictions and references must have the same length, "
                f"got {len(sources)}, {len(predictions)} and {len(references)}."
            )
        data = [
            {"src": s, "mt": p, "ref": r}
            for s, p, r in zip(sources, predictions, references)
        ]
        output = self.model.predict(data, batch_size=batch_size)
        return {"scores": output.metadata.unified_scores}