File size: 1,567 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
"""Module for dynamic scoring"""
import os
import importlib
from .scorer import Scorer, build_scorers
# Registry mapping a metric name to the scorer class registered for it.
AVAILABLE_SCORERS = {}


def get_scorers_cls(metric_names):
    """Return the scorer classes registered for the given metrics.

    Args:
        metric_names: Iterable of metric names to look up in
            ``AVAILABLE_SCORERS``.

    Returns:
        A dict mapping each requested metric name to the class stored
        for it in ``AVAILABLE_SCORERS``. An empty iterable yields an
        empty dict.

    Raises:
        ValueError: If a requested metric is not in ``AVAILABLE_SCORERS``.
            The message names the offending metric and lists the metrics
            that are currently registered.
    """
    scorers_cls = {}
    for name in metric_names:
        if name not in AVAILABLE_SCORERS:
            # Name the rejected metric: callers usually pass several at
            # once and otherwise cannot tell which one is wrong.
            raise ValueError(
                "specified metric not supported! "
                "Got {!r}; registered metrics: {}".format(
                    name, list(AVAILABLE_SCORERS)
                )
            )
        scorers_cls[name] = AVAILABLE_SCORERS[name]
    return scorers_cls
# Names exported by a star-import of this module. `register_scorer`, defined
# below, is not listed here and therefore has to be imported by name.
__all__ = ["get_scorers_cls", "build_scorers"]
def register_scorer(metric):
    """Build a class decorator that records a scorer under ``metric``.

    The returned decorator stores the decorated class in
    ``AVAILABLE_SCORERS[metric]`` and hands the class back unchanged.
    It raises ``ValueError`` when ``metric`` already has an entry, or when
    the decorated class is not a subclass of ``Scorer``.
    """

    def register_scorer_cls(cls):
        # The duplicate check comes first, so re-using a metric name is
        # reported even when the class itself would also be rejected.
        if metric in AVAILABLE_SCORERS:
            duplicate_msg = "Cannot register duplicate scorer for metric ({})"
            raise ValueError(duplicate_msg.format(metric))

        extends_scorer = issubclass(cls, Scorer)
        if not extends_scorer:
            base_msg = "scorer ({}: {}) must extend Scorer"
            raise ValueError(base_msg.format(metric, cls.__name__))

        AVAILABLE_SCORERS[metric] = cls
        return cls

    return register_scorer_cls
# Auto import python files in this directory.
# NOTE(review): presumably this is what triggers the `register_scorer`
# decorators in the individual scorer modules -- confirm against those modules.
scorer_dir = os.path.dirname(__file__)
for file in os.listdir(scorer_dir):
    path = os.path.join(scorer_dir, file)
    # Entries starting with "_" or "." (e.g. dunder files, hidden files) are
    # never imported.
    if file.startswith(("_", ".")):
        continue
    if file.endswith(".py"):
        # Strip only the trailing extension. Searching for the first ".py"
        # would truncate names that contain ".py" earlier in the file name.
        file_name = file[: -len(".py")]
    elif os.path.isdir(path):
        # Sub-directories are imported under their own name.
        file_name = file
    else:
        continue
    module = importlib.import_module("onmt.scorers." + file_name)
|