import unittest | |
from onmt.modules.structured_attention import MatrixTree | |
import torch | |
class TestStructuredAttention(unittest.TestCase):
    """Unit tests for the ``MatrixTree`` structured-attention module."""

    def test_matrix_tree_marg_pdfs_sum_to_1(self):
        """Check that the output of ``MatrixTree.forward`` sums to 1 over dim 1.

        A random ``(1, 5, 5)`` tensor is passed through ``MatrixTree`` and the
        result, summed along dimension 1, must be close to 1.0 everywhere.
        """
        # Use a locally seeded generator so the random input is reproducible
        # across runs without disturbing the global RNG state seen by other
        # tests.
        rng = torch.Generator().manual_seed(0)

        dtree = MatrixTree()
        # presumably (batch, n, n) pairwise scores -- TODO confirm against
        # MatrixTree's expected input layout.
        q = torch.rand(1, 5, 5, generator=rng)

        # NOTE(review): `forward` is called directly, as in the original test;
        # if MatrixTree is an nn.Module, calling `dtree(q)` would also run any
        # registered hooks -- confirm before switching.
        marg = dtree.forward(q)

        totals = marg.sum(1)
        # Include the actual sums in the failure message; a bare assertTrue
        # only reports "False is not true".
        self.assertTrue(
            totals.allclose(torch.tensor(1.0)),
            msg=f"marginals should sum to 1 along dim 1, got {totals}",
        )