import unittest
from onmt.modules.embeddings import Embeddings
import itertools
from copy import deepcopy
import torch
from onmt.tests.utils_for_tests import product_dict


class TestEmbeddings(unittest.TestCase):
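    """Exercise onmt.modules.embeddings.Embeddings over a grid of
    constructor options, with and without word features."""
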
INIT_CASES = list(
product_dict(
word_vec_size=[172],
word_vocab_size=[319],
word_padding_idx=[17],
position_encoding=[False, True],
feat_merge=["first", "concat", "sum", "mlp"],
feat_vec_exponent=[-1, 1.1, 0.7],
feat_vec_size=[0, 200],
feat_padding_idx=[[], [29], [0, 1]],
feat_vocab_sizes=[[], [39], [401, 39]],
dropout=[0, 0.5],
freeze_word_vecs=[False, True],
)
)
    PARAMS = list(product_dict(batch_size=[1, 14], max_seq_len=[23]))

    @classmethod
def case_is_degenerate(cls, case):
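        """Return True for inconsistent option combinations (features supplied
        to a merge mode that takes none, both feat_vec_size and
        feat_vec_exponent set, mismatched feature lists, ...) so that cases()
        skips them."""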
no_feats = len(case["feat_vocab_sizes"]) == 0
if case["feat_merge"] != "first" and no_feats:
return True
if case["feat_merge"] == "first" and not no_feats:
return True
if case["feat_merge"] == "concat" and case["feat_vec_exponent"] != -1:
return True
if no_feats and case["feat_vec_exponent"] != -1:
return True
if len(case["feat_vocab_sizes"]) != len(case["feat_padding_idx"]):
return True
if case["feat_vec_size"] == 0 and case["feat_vec_exponent"] <= 0:
return True
if case["feat_merge"] == "sum":
if case["feat_vec_exponent"] != -1:
return True
if case["feat_vec_size"] != 0:
return True
if case["feat_vec_size"] != 0 and case["feat_vec_exponent"] != -1:
return True
        return False

    @classmethod
def cases(cls):
for case in cls.INIT_CASES:
if not cls.case_is_degenerate(case):
                yield case

    @classmethod
def dummy_inputs(cls, params, init_case):
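        """Build a (batch_size, max_seq_len, 1 + n_feats) LongTensor of random
        word/feature indices, padded with the corresponding padding index past
        each sequence's sampled length."""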
max_seq_len = params["max_seq_len"]
batch_size = params["batch_size"]
fv_sizes = init_case["feat_vocab_sizes"]
n_words = init_case["word_vocab_size"]
voc_sizes = [n_words] + fv_sizes
pad_idxs = [init_case["word_padding_idx"]] + init_case["feat_padding_idx"]
lengths = torch.randint(0, max_seq_len, (batch_size,))
lengths[0] = max_seq_len
inps = torch.empty((batch_size, max_seq_len, len(voc_sizes)), dtype=torch.long)
for f, (voc_size, pad_idx) in enumerate(zip(voc_sizes, pad_idxs)):
for b, len_ in enumerate(lengths):
inps[b, :len_, f] = torch.randint(0, voc_size - 1, (len_,))
inps[b, len_:, f] = pad_idx
        return inps

    @classmethod
def expected_shape(cls, params, init_case):
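        """Expected output shape: (batch, seq, word_vec_size), widened by
        n_feats * feat_vec_size when feature embeddings are concatenated."""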
wvs = init_case["word_vec_size"]
fvs = init_case["feat_vec_size"]
nf = len(init_case["feat_vocab_sizes"])
size = wvs
if init_case["feat_merge"] not in {"sum", "mlp"}:
size += nf * fvs
        return params["batch_size"], params["max_seq_len"], size

    def test_embeddings_forward_shape(self):
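        """The forward pass returns the shape predicted by expected_shape()."""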
for params, init_case in itertools.product(self.PARAMS, self.cases()):
emb = Embeddings(**init_case)
dummy_in = self.dummy_inputs(params, init_case)
res = emb(dummy_in)
expected_shape = self.expected_shape(params, init_case)
            self.assertEqual(res.shape, expected_shape, str(init_case))

    def test_embeddings_trainable_params(self):
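        """Exactly the expected parameters are trainable: frozen word vectors
        and positional encodings must not require gradients, everything else
        must."""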
for params, init_case in itertools.product(self.PARAMS, self.cases()):
emb = Embeddings(**init_case)
trainable_params = {
n: p for n, p in emb.named_parameters() if p.requires_grad
}
# first check there's nothing unexpectedly not trainable
for key in emb.state_dict():
if key not in trainable_params:
if (
key.endswith("emb_luts.0.weight")
and init_case["freeze_word_vecs"]
):
                        # ok: word embeddings shouldn't be trainable
                        # if word vecs are frozen
continue
if key.endswith(".pe.pe"):
# ok: positional encodings shouldn't be trainable
assert init_case["position_encoding"]
continue
else:
self.fail(
"Param {:s} is unexpectedly not " "trainable.".format(key)
)
# then check nothing unexpectedly trainable
if init_case["freeze_word_vecs"]:
self.assertFalse(
any(
trainable_param.endswith("emb_luts.0.weight")
for trainable_param in trainable_params
),
"Word embedding is trainable but word vecs are freezed.",
)
if init_case["position_encoding"]:
self.assertFalse(
any(
trainable_p.endswith(".pe.pe")
for trainable_p in trainable_params
),
"Positional encoding is trainable.",
                )

    def test_embeddings_trainable_params_update(self):
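        """A single SGD step on a dummy loss changes every trainable
        parameter."""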
for params, init_case in itertools.product(self.PARAMS, self.cases()):
emb = Embeddings(**init_case)
trainable_params = {
n: p for n, p in emb.named_parameters() if p.requires_grad
}
if len(trainable_params) > 0:
old_weights = deepcopy(trainable_params)
dummy_in = self.dummy_inputs(params, init_case)
res = emb(dummy_in)
pretend_loss = res.sum()
pretend_loss.backward()
dummy_optim = torch.optim.SGD(trainable_params.values(), 1)
dummy_optim.step()
for param_name in old_weights.keys():
self.assertTrue(
trainable_params[param_name].ne(old_weights[param_name]).any(),
param_name + " " + init_case.__str__(),
)
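

if __name__ == "__main__":
    # Convenience entry point so this module can also be run directly;
    # the tests are normally collected by a test runner such as pytest.
    unittest.main()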