# Source: ReactSeq/onmt/tests/test_structured_attention.py
# (uploaded by Oopstom in "Upload 313 files", commit c668e80, verified)
import unittest
from onmt.modules.structured_attention import MatrixTree
import torch
class TestStructuredAttention(unittest.TestCase):
    """Unit tests for ``onmt.modules.structured_attention.MatrixTree``."""

    def test_matrix_tree_marg_pdfs_sum_to_1(self):
        """Marginals returned by MatrixTree sum to 1 along dim 1.

        The random scores come from a privately seeded generator so that a
        failure is reproducible, without touching the global torch RNG that
        other tests may rely on.
        """
        generator = torch.Generator().manual_seed(0)
        dtree = MatrixTree()
        # Random score tensor of shape (1, 5, 5) with values in [0, 1).
        q = torch.rand(1, 5, 5, generator=generator)
        # Call the module instance rather than .forward() so that any
        # registered nn.Module hooks are run.
        marg = dtree(q)
        sums = marg.sum(1)
        # NOTE(review): dim 1 is assumed to be the head axis, so each column
        # of marginals forms a probability distribution -- confirm against
        # MatrixTree.
        self.assertTrue(
            sums.allclose(torch.ones_like(sums)),
            msg="marginals along dim 1 do not sum to 1: {}".format(sums),
        )