|
""" |
|
Here come the tests for attention types and their compatibility |
|
""" |
|
import unittest |
|
import torch |
|
from torch.autograd import Variable |
|
|
|
import onmt |
|
|
|
|
|
class TestAttention(unittest.TestCase):
    """Tests for ``onmt.modules.GlobalAttention``."""

    def test_masked_global_attention(self):
        """Padded source positions must receive zero attention weight.

        Builds a batch of four encoder outputs with different true lengths,
        runs a single query per batch item through ``GlobalAttention`` with
        ``src_len`` supplied, and checks that every position at or beyond an
        item's length gets an attention weight of exactly 0.
        """
        src_len = torch.IntTensor([7, 3, 5, 2])

        batch_size = src_len.size(0)
        # int() so the size is a plain Python int whether ``max()`` returns a
        # number or a 0-dim tensor (this differs between torch versions).
        max_len = int(src_len.max())
        dim = 20

        enc_out = Variable(torch.randn(batch_size, max_len, dim))
        enc_final_hs = Variable(torch.randn(batch_size, dim))

        attn = onmt.modules.GlobalAttention(dim)

        _, alignments = attn(enc_final_hs, enc_out, src_len=src_len)

        # NOTE(review): assumes the alignments hold batch_size * max_len
        # weights laid out batch-major (e.g. (batch, src_len) for a 2-D
        # query); confirm against GlobalAttention.forward.
        weights = alignments.data.contiguous().view(batch_size, max_len)

        for item, (row, length) in enumerate(zip(weights.tolist(),
                                                 src_len.tolist())):
            # Positions >= length are padding and must be fully masked out.
            for pos in range(length, max_len):
                self.assertEqual(
                    0.0, row[pos],
                    msg="batch item %d: padded position %d has weight %r"
                        % (item, pos, row[pos]))
|
|
|
|
|
|
|
|
|
|