alvations committed on
Commit
9dd2404
1 Parent(s): ae3c454

added comet da

Browse files
Files changed (1) hide show
  1. cometda.py +80 -0
cometda.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pathlib
3
+
4
+ import datasets
5
+ import evaluate
6
+ from huggingface_hub import snapshot_download, login
7
+
8
+ from comet.models.multitask.unified_metric import UnifiedMetric
9
+
10
+
11
+
12
# BibTeX entry for the COMET-22 paper, surfaced through `MetricInfo.citation`.
# The backslashes in the LaTeX accent commands are doubled on purpose: in a
# regular (non-raw) string literal `\'` is an escape sequence that collapses
# to a bare `'`, which would silently turn `Jos{\'e}` into the invalid
# `Jos{'e}` in the emitted citation.
_CITATION = """\
@inproceedings{rei-etal-2022-comet,
    title = "{COMET}-22: Unbabel-{IST} 2022 Submission for the Metrics Shared Task",
    author = "Rei, Ricardo and
      C. de Souza, Jos{\\'e} G. and
      Alves, Duarte and
      Zerva, Chrysoula and
      Farinha, Ana C and
      Glushkova, Taisiya and
      Lavie, Alon and
      Coheur, Luisa and
      Martins, Andr{\\'e} F. T.",
    booktitle = "Proceedings of the Seventh Conference on Machine Translation (WMT)",
    month = dec,
    year = "2022",
    address = "Abu Dhabi, United Arab Emirates (Hybrid)",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2022.wmt-1.52",
    pages = "578--585",
}
"""


# Human-readable description, surfaced through `MetricInfo.description`.
# NOTE(review): this links to the `Unbabel/unite-mup` model card, but the
# metric class in this file loads `Unbabel/wmt22-cometkiwi-da` -- confirm
# which model card the description is meant to reference.
_DESCRIPTION = """\
From https://huggingface.co/Unbabel/unite-mup
"""
38
+
39
class COMETDA(evaluate.Metric):
    """`evaluate` metric that scores translations with a COMET `UnifiedMetric` model.

    The checkpoint of `Unbabel/wmt22-cometkiwi-da` is loaded in
    `_download_and_prepare` and kept on `self.model`; both scoring methods
    delegate to `self.model.predict`.
    """

    def _info(self):
        """Return the metric metadata: description, citation and input features.

        Both declared features (`predictions` and `references`) are plain strings.
        """
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
        )

    def _download_and_prepare(self, dl_manager):
        """Load the model checkpoint into `self.model`, downloading it if needed.

        A `*.ckpt` file under `./models--Unbabel--wmt22-cometkiwi-da/` (relative
        to the current working directory) is tried first. If none is found, or
        loading it fails, the repository snapshot is downloaded into the current
        working directory and `checkpoints/model.ckpt` is loaded from it.

        Args:
            dl_manager: Unused; accepted to match the hook signature.
        """
        try:
            # `next` raises StopIteration when no local checkpoint exists, which
            # routes us to the download branch below.
            model_checkpoint_path = next(
                pathlib.Path("./models--Unbabel--wmt22-cometkiwi-da/").rglob("*.ckpt")
            )
            self.model = UnifiedMetric.load_from_checkpoint(model_checkpoint_path)
        except Exception:
            # Deliberately broad: any failure to use the local copy falls back to
            # a fresh download. `Exception` (rather than a bare `except`) lets
            # KeyboardInterrupt / SystemExit propagate.
            model_path = snapshot_download(
                repo_id="Unbabel/wmt22-cometkiwi-da",
                # os.path.dirname('.') is '', so this resolves to the current
                # working directory.
                cache_dir=os.path.abspath(os.path.dirname(".")),
            )
            model_checkpoint_path = os.path.join(model_path, "checkpoints", "model.ckpt")
            self.model = UnifiedMetric.load_from_checkpoint(model_checkpoint_path)

    def _compute(
        self,
        predictions,
        references,
        data_keys=None,
        batch_size=8,
    ):
        """Score each (prediction, reference) pair with the loaded model.

        Lets the caller use either source inputs or reference translations as
        the ground truth by choosing the keys under which the two columns are
        handed to the model.

        Args:
            predictions: Iterable of hypothesis strings.
            references: Iterable of ground-truth strings, paired positionally
                with `predictions` (pairing stops at the shorter iterable).
            data_keys: Two-item sequence; `data_keys[0]` is the sample key used
                for each prediction and `data_keys[1]` the key used for each
                reference. Presumably COMET sample keys such as "mt", "src" or
                "ref" -- verify against the model's expected inputs.
            batch_size: Batch size forwarded to `self.model.predict`.

        Returns:
            A dict `{"scores": ...}` holding the `scores` attribute of the
            model's prediction output.

        Raises:
            TypeError: If `data_keys` is not provided.
        """
        if data_keys is None:
            # Same exception type the original raised when indexing None, but
            # with an actionable message.
            raise TypeError(
                "data_keys must be a sequence of two sample keys "
                "(first for predictions, second for references), got None"
            )
        prediction_key, reference_key = data_keys[0], data_keys[1]
        data = [
            {prediction_key: p, reference_key: r}
            for p, r in zip(predictions, references)
        ]
        return {"scores": self.model.predict(data, batch_size=batch_size).scores}

    def compute_triplet(
        self,
        predictions,
        references,
        sources,
        batch_size=8,
    ):
        """Compute unified scores from sources, hypotheses and references.

        Args:
            predictions: Iterable of hypothesis strings (sample key "mt").
            references: Iterable of reference translations (sample key "ref").
            sources: Iterable of source sentences (sample key "src").
            batch_size: Batch size forwarded to `self.model.predict`.

        The three iterables are paired positionally; pairing stops at the
        shortest one.

        Returns:
            A dict `{"scores": ...}` holding `metadata.unified_scores` from the
            model's prediction output.
        """
        data = [
            {"src": s, "mt": p, "ref": r}
            for s, p, r in zip(sources, predictions, references)
        ]
        prediction = self.model.predict(data, batch_size=batch_size)
        return {"scores": prediction.metadata.unified_scores}
80
+