unite_mup / unitemup.py
alvations's picture
unite mup
2dfd7fe
raw
history blame contribute delete
No virus
3.11 kB
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import datasets
import evaluate
from huggingface_hub import snapshot_download
from comet.models.multitask.unified_metric import UnifiedMetric
# BibTeX entry for the UniTE paper (ACL 2022); surfaced as the ``citation``
# field of the ``evaluate.MetricInfo`` built in ``UNITEMUP._info``.
_CITATION = """\
@inproceedings{wan-etal-2022-unite,
title = "{U}ni{TE}: Unified Translation Evaluation",
author = "Wan, Yu and
Liu, Dayiheng and
Yang, Baosong and
Zhang, Haibo and
Chen, Boxing and
Wong, Derek and
Chao, Lidia",
booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
month = may,
year = "2022",
address = "Dublin, Ireland",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2022.acl-long.558",
doi = "10.18653/v1/2022.acl-long.558",
pages = "8117--8127",
}
"""
# Short metric description pointing at the Hugging Face model card; surfaced
# as the ``description`` field of the ``evaluate.MetricInfo`` in ``UNITEMUP._info``.
_DESCRIPTION = """\
From https://huggingface.co/Unbabel/unite-mup
"""
class UNITEMUP(evaluate.Metric):
    """UniTE-MUP translation-evaluation metric backed by a COMET ``UnifiedMetric``.

    The ``Unbabel/unite-mup`` checkpoint is loaded once in
    ``_download_and_prepare`` and stored on ``self.model``. It is then reused by:

    * ``_compute`` -- two-input scoring (e.g. hypothesis vs. reference), reached
      through ``evaluate``'s ``compute``;
    * ``compute_triplet`` -- unified scoring from source, hypothesis and reference.
    """

    def _info(self):
        """Describe the metric: description, citation and string input features."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
        )

    def _download_and_prepare(self, dl_manager):
        """Load the UniTE-MUP checkpoint into ``self.model``.

        A ``*.ckpt`` file already present anywhere under
        ``./models--Unbabel--unite-mup/`` (relative to the current working
        directory) is preferred. If none is found, or it fails to load, the
        model is fetched with ``snapshot_download``, using the current working
        directory as the cache, and ``checkpoints/model.ckpt`` inside the
        snapshot is loaded instead.

        Args:
            dl_manager: Download manager supplied by ``evaluate``; unused here.
        """
        import warnings

        local_dir = pathlib.Path("./models--Unbabel--unite-mup/")
        local_checkpoint = next(local_dir.rglob("*.ckpt"), None) if local_dir.is_dir() else None
        if local_checkpoint is not None:
            try:
                self.model = UnifiedMetric.load_from_checkpoint(local_checkpoint)
                return
            except Exception as err:  # best-effort: a broken local copy falls back to download
                warnings.warn(
                    f"Could not load local checkpoint {local_checkpoint} ({err!r}); "
                    "downloading Unbabel/unite-mup instead."
                )
        # snapshot_download with cache_dir=cwd lays the repo out as
        # ./models--Unbabel--unite-mup/, which is what the local lookup above scans.
        model_path = snapshot_download(repo_id="Unbabel/unite-mup", cache_dir=os.getcwd())
        model_checkpoint_path = os.path.join(model_path, "checkpoints", "model.ckpt")
        self.model = UnifiedMetric.load_from_checkpoint(model_checkpoint_path)

    def _compute(
        self,
        predictions,
        references,
        data_keys=None,
        batch_size=8,
    ):
        """Score each prediction against its paired ``references`` entry.

        ``data_keys`` lets the caller choose which COMET input fields the two
        columns fill. This allows either source inputs or reference translations
        to serve as the ground truth, e.g. ``("mt", "src")`` for source-based
        scoring.

        Args:
            predictions: Hypothesis strings.
            references: Ground-truth strings (references or sources), one per prediction.
            data_keys: Sequence whose first two items are the field names for
                predictions and references. Defaults to ``("mt", "ref")``.
                Previously, omitting it raised ``TypeError``.
            batch_size: Batch size forwarded to ``self.model.predict``.

        Returns:
            ``{"scores": ...}`` holding the ``scores`` attribute of the model's
            prediction output.
        """
        if data_keys is None:
            data_keys = ("mt", "ref")
        pred_key, ref_key = data_keys[0], data_keys[1]
        data = [{pred_key: p, ref_key: r} for p, r in zip(predictions, references)]
        return {"scores": self.model.predict(data, batch_size=batch_size).scores}

    def compute_triplet(
        self,
        predictions,
        references,
        sources,
        batch_size=8,
    ):
        """Compute unified scores from source, hypothesis and reference together.

        Args:
            predictions: Hypothesis strings (COMET ``"mt"`` field).
            references: Reference translations (``"ref"`` field).
            sources: Source sentences (``"src"`` field).
            batch_size: Batch size forwarded to ``self.model.predict``.

        Returns:
            ``{"scores": ...}`` holding ``metadata.unified_scores`` of the model's
            prediction output.

        Raises:
            ValueError: If the three inputs differ in length. ``zip`` would
                otherwise drop the surplus items silently.
        """
        predictions, references, sources = list(predictions), list(references), list(sources)
        if not len(predictions) == len(references) == len(sources):
            raise ValueError(
                "predictions, references and sources must have the same length, got "
                f"{len(predictions)}, {len(references)} and {len(sources)}"
            )
        data = [{"src": s, "mt": p, "ref": r} for s, p, r in zip(sources, predictions, references)]
        return {"scores": self.model.predict(data, batch_size=batch_size).metadata.unified_scores}