File size: 1,567 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Module for dynamic scoring"""
import os
import importlib

from .scorer import Scorer, build_scorers

# Registry of scorer classes keyed by metric name. It is filled in by
# `register_scorer` and read by `get_scorers_cls`.
AVAILABLE_SCORERS = {}


def get_scorers_cls(metric_names):
    """Return the scorer classes registered for the requested metrics.

    Args:
        metric_names: iterable of metric names to look up in
            ``AVAILABLE_SCORERS``.

    Returns:
        dict mapping each requested metric name to its registered scorer
        class. Keys are in the order the names were first given.

    Raises:
        ValueError: if a name has no registered scorer. The message names
            the offending metric and lists the supported ones.
    """
    scorers_cls = {}
    for name in metric_names:
        try:
            scorers_cls[name] = AVAILABLE_SCORERS[name]
        except KeyError:
            # Keys are stringified before sorting, so a registry with
            # mixed key types cannot make the error path itself crash.
            supported = ", ".join(sorted(map(str, AVAILABLE_SCORERS)))
            raise ValueError(
                "specified metric not supported: {!r} (supported: {})".format(
                    name, supported or "none"
                )
            ) from None
    return scorers_cls


# Public API of this package. `register_scorer` is listed because it is the
# decorator used to add new scorer classes to the registry.
__all__ = ["get_scorers_cls", "build_scorers", "register_scorer"]


def register_scorer(metric):
    """Build a class decorator that registers a scorer under ``metric``.

    Usage::

        @register_scorer(metric="my_metric")
        class MyScorer(Scorer):
            ...

    Args:
        metric: key under which the decorated class is stored in
            ``AVAILABLE_SCORERS``. This is the name later passed to
            ``get_scorers_cls``.

    Returns:
        A decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: raised when the decorator is applied, in two cases:
            ``metric`` is already registered, or the decorated object is
            not a subclass of ``Scorer``.
    """

    def register_scorer_cls(cls):
        if metric in AVAILABLE_SCORERS:
            raise ValueError(
                "Cannot register duplicate scorer for metric ({})".format(metric)
            )
        # Check isinstance first. For a non-class argument, issubclass() itself
        # raises TypeError, which would bypass this explicit error.
        if not (isinstance(cls, type) and issubclass(cls, Scorer)):
            raise ValueError(
                "scorer ({}: {}) must extend Scorer".format(
                    metric, getattr(cls, "__name__", repr(cls))
                )
            )
        AVAILABLE_SCORERS[metric] = cls
        return cls

    return register_scorer_cls


# Auto-import every public module (``*.py``) and subpackage in this directory.
# Importing a submodule runs its top-level code. That code is presumably where
# scorers register themselves via `register_scorer` (not visible here; verify
# in the submodules).
# Entries starting with "_" (e.g. __init__.py, __pycache__) or "." are skipped.
# The listing is sorted so that import/registration order is deterministic
# across filesystems.
scorer_dir = os.path.dirname(__file__)
for file in sorted(os.listdir(scorer_dir)):
    if file.startswith(("_", ".")):
        continue
    path = os.path.join(scorer_dir, file)
    if file.endswith(".py"):
        # Strip only the trailing suffix, not everything after the first ".py".
        file_name = file[: -len(".py")]
    elif os.path.isdir(path):
        file_name = file
    else:
        continue
    # Resolve against this package's actual name instead of hard-coding
    # "onmt.scorers", so the import still works if the package is relocated.
    importlib.import_module(__name__ + "." + file_name)