"""Shape and type smoke tests for OpenNMT embeddings, encoders and NMT models.

The module parses the default training options once, then generates one
``unittest`` method on :class:`TestModel` per entry of the parameter tables
at the bottom of the file (see :func:`_add_test`).
"""

import copy
import unittest
from collections import Counter

import torch
import pyonmttok

import onmt
import onmt.inputters
import onmt.opts
from onmt.constants import DefaultTokens
from onmt.model_builder import build_embeddings, build_encoder, build_decoder
from onmt.utils.parse import ArgumentParser

parser = ArgumentParser(description="train.py")
onmt.opts.model_opts(parser)
onmt.opts.distributed_opts(parser)
onmt.opts._add_train_general_opts(parser)

# -data option is required, but not used in this test, so dummy.
opt = parser.parse_known_args(["-data", "dummy"])[0]


class TestModel(unittest.TestCase):
    """Forward-pass size checks driven by option overrides.

    The ``*_forward`` methods are not collected directly by ``unittest``
    (their names do not start with ``test``); ``_add_test`` wraps them into
    generated ``test_*`` methods, each running on its own copy of ``opt``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Module-level default options; every generated test deep-copies them.
        self.opt = opt

    @staticmethod
    def _build_special_vocab():
        """Build a vocabulary that holds only the four special tokens."""
        return pyonmttok.build_vocab_from_tokens(
            Counter(),
            maximum_size=0,
            minimum_frequency=1,
            special_tokens=[
                DefaultTokens.UNK,
                DefaultTokens.PAD,
                DefaultTokens.BOS,
                DefaultTokens.EOS,
            ],
        )

    def get_vocabs(self):
        """Return ``{"src": vocab, "tgt": vocab}`` with two separate vocabs."""
        return {
            "src": self._build_special_vocab(),
            "tgt": self._build_special_vocab(),
        }

    def get_batch(self, source_l=3, bsize=1):
        """Create a dummy batch filled with ones.

        Args:
            source_l: sequence length of the generated source and target
            bsize: number of sequences in the batch

        Returns:
            ``(test_src, test_tgt, test_length)`` where source and target
            are long tensors of shape ``(bsize, source_l, 1)`` and
            ``test_length`` is a long tensor of shape ``(bsize,)`` filled
            with ``source_l``.
        """
        # batch x len x nfeat
        test_src = torch.ones(bsize, source_l, 1).long()
        test_tgt = torch.ones(bsize, source_l, 1).long()
        test_length = torch.ones(bsize).fill_(source_l).long()
        return test_src, test_tgt, test_length

    def embeddings_forward(self, opt, source_l=3, bsize=1):
        """
        Tests if the embeddings work as expected

        Args:
            opt: set of options
            source_l: Length of generated input sentence
            bsize: Batchsize of generated input
        """
        vocabs = self.get_vocabs()
        emb = build_embeddings(opt, vocabs)
        test_src, _, _ = self.get_batch(source_l=source_l, bsize=bsize)

        if opt.decoder_type == "transformer":
            # The transformer case feeds the source concatenated with itself
            # along the length dimension, so the output is twice as long.
            doubled_src = torch.cat([test_src, test_src], 1)
            res = emb(doubled_src)
            expected_len = source_l * 2
        else:
            res = emb(test_src)
            expected_len = source_l

        expected_size = torch.Size((bsize, expected_len, opt.src_word_vec_size))
        self.assertEqual(res.size(), expected_size)

    def encoder_forward(self, opt, source_l=3, bsize=1):
        """
        Tests if the encoder works as expected

        Args:
            opt: set of options
            source_l: Length of generated input sentence
            bsize: Batchsize of generated input
        """
        if opt.hidden_size > 0:
            opt.enc_hid_size = opt.hidden_size
        vocabs = self.get_vocabs()
        embeddings = build_embeddings(opt, vocabs)
        enc = build_encoder(opt, embeddings)

        test_src, _, test_length = self.get_batch(source_l=source_l, bsize=bsize)
        enc_out, hidden_t, _ = enc(test_src, test_length)

        # Expected sizes. The layer count is read from the same ``opt`` the
        # encoder was built from, not from the unmodified defaults.
        expected_hid = torch.Size((opt.enc_layers, bsize, opt.enc_hid_size))
        expected_out = torch.Size((bsize, source_l, opt.dec_hid_size))

        # Ensure correct sizes and types. Both elements of the final state
        # are checked separately: assertEqual's third positional argument is
        # the failure message, not another value to compare.
        self.assertEqual(expected_hid, hidden_t[0].size())
        self.assertEqual(expected_hid, hidden_t[1].size())
        self.assertEqual(expected_out, enc_out.size())
        self.assertIsInstance(enc_out, torch.Tensor)

    def nmtmodel_forward(self, opt, source_l=3, bsize=1):
        """
        Creates a nmtmodel with a custom opt function.
        Forwards a testbatch and checks output size.

        Args:
            opt: Namespace with options
            source_l: length of input sequence
            bsize: batchsize
        """
        if opt.hidden_size > 0:
            opt.enc_hid_size = opt.hidden_size
            opt.dec_hid_size = opt.hidden_size
        vocabs = self.get_vocabs()

        enc_embeddings = build_embeddings(opt, vocabs)
        enc = build_encoder(opt, enc_embeddings)

        dec_embeddings = build_embeddings(opt, vocabs, for_encoder=False)
        dec = build_decoder(opt, dec_embeddings)

        model = onmt.models.model.NMTModel(enc, dec)

        test_src, test_tgt, test_length = self.get_batch(
            source_l=source_l, bsize=bsize
        )
        output, _ = model(test_src, test_tgt, test_length)

        # Make sure that output has the correct size and type; the expected
        # output length is one less than the length of the target fed in.
        expected_size = torch.Size((bsize, source_l - 1, opt.dec_hid_size))
        self.assertEqual(output.size(), expected_size)
        self.assertIsInstance(output, torch.Tensor)


def _add_test(param_setting, methodname):
    """
    Adds a Test to TestModel according to settings

    Args:
        param_setting: list of tuples of (param, setting); an empty list
            registers the "standard" test that runs on the default options
        methodname: name of the TestModel method that gets called
    """

    def test_method(self):
        # Work on a private copy so overrides never leak between tests.
        local_opt = copy.deepcopy(self.opt)
        if param_setting:
            for param, setting in param_setting:
                setattr(local_opt, param, setting)
        ArgumentParser.update_model_opts(local_opt)
        getattr(self, methodname)(local_opt)

    if param_setting:
        # Strip whitespace from the repr so the name is a single token.
        name = "test_" + methodname + "_" + "_".join(str(param_setting).split())
    else:
        name = "test_" + methodname + "_standard"
    test_method.__name__ = name
    setattr(TestModel, name, test_method)


# ---------------------------------------------------------------------------
# TEST PARAMETERS
# ---------------------------------------------------------------------------
opt.brnn = False

test_embeddings = [[], [("decoder_type", "transformer")]]

for p in test_embeddings:
    _add_test(p, "embeddings_forward")

tests_encoder = [
    [],
    [("encoder_type", "mean")],
    # [('encoder_type', 'transformer'),
    # ('word_vec_size', 16), ('hidden_size', 16)],
]

for p in tests_encoder:
    _add_test(p, "encoder_forward")

tests_nmtmodel = [
    [("rnn_type", "GRU")],
    [("layers", 10)],
    [("input_feed", 0)],
    [
        ("decoder_type", "transformer"),
        ("encoder_type", "transformer"),
        ("src_word_vec_size", 16),
        ("tgt_word_vec_size", 16),
        ("hidden_size", 16),
    ],
    [
        ("decoder_type", "transformer"),
        ("encoder_type", "transformer"),
        ("src_word_vec_size", 16),
        ("tgt_word_vec_size", 16),
        ("hidden_size", 16),
        ("position_encoding", True),
    ],
    [("coverage_attn", True)],
    [("copy_attn", True)],
    [("global_attention", "mlp")],
    [("context_gate", "both")],
    [("context_gate", "target")],
    [("context_gate", "source")],
    [("encoder_type", "brnn"), ("brnn_merge", "sum")],
    [("encoder_type", "brnn")],
    [("decoder_type", "cnn"), ("encoder_type", "cnn")],
    [("encoder_type", "rnn"), ("global_attention", None)],
    [
        ("encoder_type", "rnn"),
        ("global_attention", None),
        ("copy_attn", True),
        ("copy_attn_type", "general"),
    ],
    [
        ("encoder_type", "rnn"),
        ("global_attention", "mlp"),
        ("copy_attn", True),
        ("copy_attn_type", "general"),
    ],
    [],
]

if onmt.modules.sru.check_sru_requirement():
    # Only do the SRU test if the requirement is satisfied.
    # SRU doesn't support input_feed.
    tests_nmtmodel.append([("rnn_type", "SRU"), ("input_feed", 0)])

for p in tests_nmtmodel:
    _add_test(p, "nmtmodel_forward")