|
"""Module defining decoders.""" |
|
import os |
|
import importlib |
|
from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, StdRNNDecoder |
|
from onmt.decoders.transformer import TransformerDecoder, TransformerLMDecoder |
|
from onmt.decoders.cnn_decoder import CNNDecoder |
|
|
|
|
|
# Registry mapping a decoder's string identifier to its class.
# `get_decoders_cls` reads it and `register_decoder` adds entries to it.
str2dec = {
    "rnn": StdRNNDecoder,
    "ifrnn": InputFeedRNNDecoder,
    "cnn": CNNDecoder,
    "transformer": TransformerDecoder,
    "transformer_lm": TransformerLMDecoder,
}
|
|
|
# Names exported by `from onmt.decoders import *`.
# Kept as a list (not a tuple) because `register_decoder` appends to it.
__all__ = [
    "DecoderBase",
    "TransformerDecoder",
    "StdRNNDecoder",
    "CNNDecoder",
    "InputFeedRNNDecoder",
    "str2dec",
    "TransformerLMDecoder",
]
|
|
|
|
|
def get_decoders_cls(decoders_names):
    """Look up the decoder classes registered under ``decoders_names``.

    Args:
        decoders_names: iterable of decoder identifiers (keys of ``str2dec``).

    Returns:
        dict mapping each requested name to its class from ``str2dec``.

    Raises:
        ValueError: on the first name that is not a key of ``str2dec``.
    """
    selected = {}
    for dec_name in decoders_names:
        if dec_name in str2dec:
            selected[dec_name] = str2dec[dec_name]
            continue
        # Unknown identifier: stop immediately rather than skipping it.
        raise ValueError("%s decoder not supported!" % dec_name)
    return selected
|
|
|
|
|
def register_decoder(name):
    """Return a class decorator that registers a decoder under ``name``.

    Applying the returned decorator to a class stores it in ``str2dec``
    under ``name`` and appends the class's ``__name__`` to ``__all__``.

    Args:
        name: identifier to register the decorated class under.

    Returns:
        A decorator that takes a class and returns that same class.

    Raises:
        ValueError: raised by the decorator (not by this function) when
            ``name`` is already a key of ``str2dec``, or when the decorated
            class is not a subclass of ``DecoderBase``.
    """

    def register_decoder_cls(cls):
        if name in str2dec:
            raise ValueError("Cannot register duplicate decoder ({})".format(name))
        if not issubclass(cls, DecoderBase):
            # `__name__` (two trailing underscores) is the class-name attribute;
            # the misspelled `__name_` raised AttributeError here and hid the
            # intended ValueError.
            raise ValueError(
                f"decoder ({name}: {cls.__name__}) must extend DecoderBase"
            )
        str2dec[name] = cls
        __all__.append(cls.__name__)
        return cls

    return register_decoder_cls
|
|
|
|
|
|
|
# Import every module and sub-directory that sits next to this file.
# NOTE(review): presumably this is so modules using `@register_decoder` add
# themselves to `str2dec` at import time; confirm against those modules.
decoder_dir = os.path.dirname(__file__)
for file in os.listdir(decoder_dir):
    path = os.path.join(decoder_dir, file)
    # Entries starting with "_" or "." are never imported.
    if file.startswith(("_", ".")):
        continue
    if file.endswith(".py"):
        # Keep the text before the first ".py" as the module name.
        file_name = file.partition(".py")[0]
    elif os.path.isdir(path):
        file_name = file
    else:
        continue
    module = importlib.import_module("onmt.decoders." + file_name)
|
|