ReactSeq / onmt /decoders /__init__.py
Oopstom's picture
Upload 313 files
c668e80 verified
raw
history blame
1.93 kB
"""Module defining decoders."""
import os
import importlib
from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, StdRNNDecoder
from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder
from onmt.decoders.cnn_decoder import CNNDecoder
str2dec = {
"rnn": StdRNNDecoder,
"ifrnn": InputFeedRNNDecoder,
"cnn": CNNDecoder,
"transformer": TransformerDecoder,
"transformer_lm": TransformerLMDecoder,
}
__all__ = [
"DecoderBase",
"TransformerDecoder",
"StdRNNDecoder",
"CNNDecoder",
"InputFeedRNNDecoder",
"str2dec",
"TransformerLMDecoder",
]
def get_decoders_cls(decoders_names):
"""Return valid encoder class indicated in `decoders_names`."""
decoders_cls = {}
for name in decoders_names:
if name not in str2dec:
raise ValueError("%s decoder not supported!" % name)
decoders_cls[name] = str2dec[name]
return decoders_cls
def register_decoder(name):
"""Encoder register that can be used to add new encoder class."""
def register_decoder_cls(cls):
if name in str2dec:
raise ValueError("Cannot register duplicate decoder ({})".format(name))
if not issubclass(cls, DecoderBase):
raise ValueError(f"decoder ({name}: {cls.__name_}) must extend DecoderBase")
str2dec[name] = cls
__all__.append(cls.__name__) # added to be complete
return cls
return register_decoder_cls
# Auto import python files in this directory
decoder_dir = os.path.dirname(__file__)
for file in os.listdir(decoder_dir):
path = os.path.join(decoder_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
file_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("onmt.decoders." + file_name)