File size: 1,930 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
"""Module defining decoders."""
import os
import importlib
from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, StdRNNDecoder
from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder
from onmt.decoders.cnn_decoder import CNNDecoder
# Registry mapping a decoder-type name to the decoder class implementing it.
# ``register_decoder`` (below) inserts additional entries at import time.
str2dec = dict(
    rnn=StdRNNDecoder,
    ifrnn=InputFeedRNNDecoder,
    cnn=CNNDecoder,
    transformer=TransformerDecoder,
    transformer_lm=TransformerLMDecoder,
)

# Public names of this package; ``register_decoder`` appends the name of every
# class it registers.
__all__ = [
    "DecoderBase",
    "TransformerDecoder",
    "StdRNNDecoder",
    "CNNDecoder",
    "InputFeedRNNDecoder",
    "str2dec",
    "TransformerLMDecoder",
]
def get_decoders_cls(decoders_names):
    """Return the decoder classes registered under ``decoders_names``.

    Args:
        decoders_names: iterable of decoder names; each must be a key of
            ``str2dec`` (a built-in name or one added via ``register_decoder``).

    Returns:
        dict mapping each requested name to its decoder class, in the order
        the names were given.

    Raises:
        ValueError: for the first name that is not present in ``str2dec``.
    """
    decoders_cls = {}
    for name in decoders_names:
        try:
            decoders_cls[name] = str2dec[name]
        except KeyError:
            # Report unknown names as ValueError (the documented contract);
            # the KeyError itself carries no extra information.
            raise ValueError("%s decoder not supported!" % name) from None
    return decoders_cls
def register_decoder(name):
    """Return a class decorator that registers a decoder class under ``name``.

    Usage::

        @register_decoder("my_decoder")
        class MyDecoder(DecoderBase):
            ...

    Args:
        name: key under which the decorated class is stored in ``str2dec``.

    Returns:
        A decorator that validates the class, stores it in ``str2dec``,
        appends its ``__name__`` to ``__all__`` and returns the class unchanged.

    Raises (from the returned decorator):
        ValueError: if ``name`` is already registered, or if the decorated
            object is not a class extending ``DecoderBase``.
    """

    def register_decoder_cls(cls):
        if name in str2dec:
            raise ValueError("Cannot register duplicate decoder ({})".format(name))
        # Check ``isinstance(cls, type)`` first: ``issubclass`` raises a
        # TypeError for non-class arguments, which would mask the intended error.
        if not (isinstance(cls, type) and issubclass(cls, DecoderBase)):
            cls_name = getattr(cls, "__name__", repr(cls))
            raise ValueError(f"decoder ({name}: {cls_name}) must extend DecoderBase")
        str2dec[name] = cls
        __all__.append(cls.__name__)  # keep the package's export list complete
        return cls

    return register_decoder_cls
# Auto-import every public Python module and subpackage in this directory.
# Importing executes each module, so any class decorated with
# ``@register_decoder`` there is added to ``str2dec`` when this package loads.
decoder_dir = os.path.dirname(__file__)
# ``os.listdir`` returns entries in arbitrary order; sort them so that the
# import (and therefore registration / ``__all__``) order is deterministic.
for file in sorted(os.listdir(decoder_dir)):
    path = os.path.join(decoder_dir, file)
    # Skip private/dunder entries (e.g. ``__init__.py``, ``__pycache__``) and
    # hidden files.
    if file.startswith(("_", ".")):
        continue
    if file.endswith(".py"):
        # Strip exactly the ".py" suffix.
        file_name = file[: -len(".py")]
    elif os.path.isdir(path):
        file_name = file
    else:
        continue
    importlib.import_module("onmt.decoders." + file_name)
|