File size: 1,930 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""Module defining decoders."""
import os
import importlib
from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, StdRNNDecoder
from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder
from onmt.decoders.cnn_decoder import CNNDecoder


# Registry of decoder classes keyed by their short name. ``get_decoders_cls``
# looks names up here and ``register_decoder`` adds new entries at runtime.
str2dec = dict(
    rnn=StdRNNDecoder,
    ifrnn=InputFeedRNNDecoder,
    cnn=CNNDecoder,
    transformer=TransformerDecoder,
    transformer_lm=TransformerLMDecoder,
)

# Public API of the package; ``register_decoder`` appends the class name of
# every decoder registered later on.
__all__ = [
    "DecoderBase",
    "TransformerDecoder",
    "StdRNNDecoder",
    "CNNDecoder",
    "InputFeedRNNDecoder",
    "str2dec",
    "TransformerLMDecoder",
]


def get_decoders_cls(decoders_names):
    """Return the decoder classes named in ``decoders_names``.

    Args:
        decoders_names: iterable of decoder names, each of which must be a
            key of ``str2dec``.

    Returns:
        dict mapping each requested name to its decoder class, in the order
        the names were first seen (repeated names collapse to one entry).

    Raises:
        ValueError: if a name is not registered in ``str2dec``.
    """
    decoders_cls = {}
    for name in decoders_names:
        try:
            decoders_cls[name] = str2dec[name]
        except KeyError:
            # Report an unknown name as a ValueError and hide the internal
            # KeyError, which adds nothing for the caller.
            raise ValueError("%s decoder not supported!" % name) from None
    return decoders_cls


def register_decoder(name):
    """Return a class decorator that registers a new decoder under ``name``.

    Usage::

        @register_decoder("my_decoder")
        class MyDecoder(DecoderBase):
            ...

    Args:
        name: key under which the decorated class is stored in ``str2dec``.

    Returns:
        A decorator that validates the class, adds it to ``str2dec``, appends
        its class name to ``__all__``, and returns the class unchanged.

    Raises (when the decorator is applied):
        ValueError: if ``name`` is already registered, or if the class does
            not subclass ``DecoderBase``.
    """

    def register_decoder_cls(cls):
        if name in str2dec:
            raise ValueError("Cannot register duplicate decoder ({})".format(name))
        if not issubclass(cls, DecoderBase):
            # ``__name__`` (two trailing underscores): the previous
            # ``__name_`` raised AttributeError instead of this ValueError.
            raise ValueError(
                f"decoder ({name}: {cls.__name__}) must extend DecoderBase"
            )
        str2dec[name] = cls
        # NOTE(review): the class name is exported, but the class itself is
        # not bound as an attribute of this package, so
        # ``from onmt.decoders import *`` may fail for it — confirm intent.
        __all__.append(cls.__name__)
        return cls

    return register_decoder_cls


# Auto-import every public module (``*.py``) and subdirectory in this package
# directory, presumably so that decorator-based registrations such as
# ``register_decoder`` in those modules take effect on package import.
# Names starting with "_" or "." (``__init__.py``, ``__pycache__``, hidden
# files) are skipped. The listing is sorted so that import and registration
# order is deterministic; ``os.listdir`` order is arbitrary.
decoder_dir = os.path.dirname(__file__)
for file in sorted(os.listdir(decoder_dir)):
    path = os.path.join(decoder_dir, file)
    if not file.startswith(("_", ".")) and (
        file.endswith(".py") or os.path.isdir(path)
    ):
        # Strip only the trailing ".py" suffix; ``file.find(".py")`` would cut
        # at the first occurrence (e.g. "a.py_b.py" -> "a").
        file_name = file[: -len(".py")] if file.endswith(".py") else file
        module = importlib.import_module("onmt.decoders." + file_name)