import itertools
import unittest
from copy import deepcopy

import torch

from onmt.modules.embeddings import Embeddings
from onmt.tests.utils_for_tests import product_dict


class TestEmbeddings(unittest.TestCase):
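    # INIT_CASES enumerates the Cartesian product of Embeddings constructor
    # arguments; combinations that do not make sense together are filtered
    # out by case_is_degenerate() below.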
    INIT_CASES = list(
        product_dict(
            word_vec_size=[172],
            word_vocab_size=[319],
            word_padding_idx=[17],
            position_encoding=[False, True],
            feat_merge=["first", "concat", "sum", "mlp"],
            feat_vec_exponent=[-1, 1.1, 0.7],
            feat_vec_size=[0, 200],
            feat_padding_idx=[[], [29], [0, 1]],
            feat_vocab_sizes=[[], [39], [401, 39]],
            dropout=[0, 0.5],
            freeze_word_vecs=[False, True],
        )
    )
    PARAMS = list(product_dict(batch_size=[1, 14], max_seq_len=[23]))

    @classmethod
    def case_is_degenerate(cls, case):
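        # Reject constructor-argument combinations that are inconsistent,
        # e.g. a feature-merge mode other than "first" with no features,
        # or a fixed feat_vec_size combined with a feat_vec_exponent.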
        no_feats = len(case["feat_vocab_sizes"]) == 0
        if case["feat_merge"] != "first" and no_feats:
            return True
        if case["feat_merge"] == "first" and not no_feats:
            return True
        if case["feat_merge"] == "concat" and case["feat_vec_exponent"] != -1:
            return True
        if no_feats and case["feat_vec_exponent"] != -1:
            return True
        if len(case["feat_vocab_sizes"]) != len(case["feat_padding_idx"]):
            return True
        if case["feat_vec_size"] == 0 and case["feat_vec_exponent"] <= 0:
            return True
        if case["feat_merge"] == "sum":
            if case["feat_vec_exponent"] != -1:
                return True
            if case["feat_vec_size"] != 0:
                return True
        if case["feat_vec_size"] != 0 and case["feat_vec_exponent"] != -1:
            return True
        return False

    @classmethod
    def cases(cls):
        for case in cls.INIT_CASES:
            if not cls.case_is_degenerate(case):
                yield case

    @classmethod
    def dummy_inputs(cls, params, init_case):
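        # Build a random (batch_size, max_seq_len, 1 + n_feats) LongTensor of
        # word and feature indices, padding each sequence past its length.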
        max_seq_len = params["max_seq_len"]
        batch_size = params["batch_size"]
        fv_sizes = init_case["feat_vocab_sizes"]
        n_words = init_case["word_vocab_size"]
        voc_sizes = [n_words] + fv_sizes
        pad_idxs = [init_case["word_padding_idx"]] + init_case["feat_padding_idx"]
        lengths = torch.randint(0, max_seq_len, (batch_size,))
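        # make sure at least one sequence spans the full max_seq_len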
        lengths[0] = max_seq_len
        inps = torch.empty((batch_size, max_seq_len, len(voc_sizes)), dtype=torch.long)
        for f, (voc_size, pad_idx) in enumerate(zip(voc_sizes, pad_idxs)):
            for b, len_ in enumerate(lengths):
                inps[b, :len_, f] = torch.randint(0, voc_size - 1, (len_,))
                inps[b, len_:, f] = pad_idx
        return inps

    @classmethod
    def expected_shape(cls, params, init_case):
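        # "concat" appends each feature embedding to the word embedding;
        # "sum" and "mlp" produce an output of size word_vec_size.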
        wvs = init_case["word_vec_size"]
        fvs = init_case["feat_vec_size"]
        nf = len(init_case["feat_vocab_sizes"])
        size = wvs
        if init_case["feat_merge"] not in {"sum", "mlp"}:
            size += nf * fvs
        return params["batch_size"], params["max_seq_len"], size

    def test_embeddings_forward_shape(self):
        for params, init_case in itertools.product(self.PARAMS, self.cases()):
            emb = Embeddings(**init_case)
            dummy_in = self.dummy_inputs(params, init_case)
            res = emb(dummy_in)
            expected_shape = self.expected_shape(params, init_case)
            self.assertEqual(res.shape, expected_shape, str(init_case))

    def test_embeddings_trainable_params(self):
        for params, init_case in itertools.product(self.PARAMS, self.cases()):
            emb = Embeddings(**init_case)
            trainable_params = {
                n: p for n, p in emb.named_parameters() if p.requires_grad
            }
            # first check there's nothing unexpectedly not trainable
            for key in emb.state_dict():
                if key not in trainable_params:
                    if (
                        key.endswith("emb_luts.0.weight")
                        and init_case["freeze_word_vecs"]
                    ):
                        # ok: word embeddings shouldn't be trainable
                        # if word vecs are frozen
                        continue
                    if key.endswith(".pe.pe"):
                        # ok: positional encodings shouldn't be trainable
                        assert init_case["position_encoding"]
                        continue
                    else:
                        self.fail(
                            "Param {:s} is unexpectedly not trainable.".format(key)
                        )
            # then check nothing unexpectedly trainable
            if init_case["freeze_word_vecs"]:
                self.assertFalse(
                    any(
                        trainable_param.endswith("emb_luts.0.weight")
                        for trainable_param in trainable_params
                    ),
                    "Word embedding is trainable but word vecs are freezed.",
                )
            if init_case["position_encoding"]:
                self.assertFalse(
                    any(
                        trainable_p.endswith(".pe.pe")
                        for trainable_p in trainable_params
                    ),
                    "Positional encoding is trainable.",
                )

    def test_embeddings_trainable_params_update(self):
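        # Take a single SGD step on a dummy loss and check that every
        # trainable parameter actually changed.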
        for params, init_case in itertools.product(self.PARAMS, self.cases()):
            emb = Embeddings(**init_case)
            trainable_params = {
                n: p for n, p in emb.named_parameters() if p.requires_grad
            }
            if len(trainable_params) > 0:
                old_weights = deepcopy(trainable_params)
                dummy_in = self.dummy_inputs(params, init_case)
                res = emb(dummy_in)
                pretend_loss = res.sum()
                pretend_loss.backward()
                dummy_optim = torch.optim.SGD(trainable_params.values(), lr=1.0)
                dummy_optim.step()
                for param_name in old_weights.keys():
                    self.assertTrue(
                        trainable_params[param_name].ne(old_weights[param_name]).any(),
                        param_name + " " + init_case.__str__(),
                    )