File size: 348 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
import unittest
from onmt.modules.structured_attention import MatrixTree
import torch
class TestStructuredAttention(unittest.TestCase):
    """Unit tests for ``onmt.modules.structured_attention.MatrixTree``."""

    def test_matrix_tree_marg_pdfs_sum_to_1(self):
        """Summing MatrixTree's output over dim 1 must give 1 everywhere.

        The input is a random (1, 5, 5) score tensor. It is drawn from a
        dedicated, seeded generator so the test is reproducible and does
        not mutate torch's global RNG state, which other tests may rely on.
        """
        generator = torch.Generator().manual_seed(0)
        scores = torch.rand(1, 5, 5, generator=generator)

        dtree = MatrixTree()
        # Invoke the module instead of ``.forward`` so any registered
        # nn.Module hooks run.
        # NOTE(review): assumes MatrixTree subclasses torch.nn.Module - confirm.
        marg = dtree(scores)

        col_sums = marg.sum(1)
        self.assertTrue(
            col_sums.allclose(torch.ones_like(col_sums)),
            msg="marginals do not sum to 1 over dim 1: {}".format(col_sums),
        )


if __name__ == "__main__":
    unittest.main()
|