File size: 1,185 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import unittest
from onmt.translate import GeneratorLM
import torch
class TestGeneratorLM(unittest.TestCase):
    """Tests for ``GeneratorLM.split_src_to_prevent_padding``.

    The shape assertions below treat ``src`` as ``(batch, seq_len, n_feats)``
    with ``src_len`` holding one length per batch entry (dim 0 of ``src``).
    """

    def test_split_src_to_prevent_padding_target_prefix_is_none_when_equal_size(  # noqa: E501
        self,
    ):
        """No target prefix is returned when every length equals ``seq_len``."""
        batch_size = 6
        default_length = 5
        src = torch.randint(0, 10, (batch_size, default_length, 1))
        # One length per batch entry, so the fixture agrees with src's dim 0.
        src_len = default_length * torch.ones(batch_size, dtype=torch.int)
        _, _, target_prefix = GeneratorLM.split_src_to_prevent_padding(
            src, src_len
        )
        self.assertIsNone(target_prefix)

    def test_split_src_to_prevent_padding_target_prefix_is_ok_when_different_size(  # noqa: E501
        self,
    ):
        """One shorter sequence shrinks the whole batch to its length.

        After the split, ``src`` is ``new_length`` long along dim 1,
        ``target_prefix`` holds the remaining ``default_length - new_length``
        positions along dim 1 (only its shape is checked here), and every
        entry of ``src_len`` equals ``new_length``.
        """
        batch_size = 6
        default_length = 5
        new_length = 4
        src = torch.randint(0, 10, (batch_size, default_length, 1))
        src_len = default_length * torch.ones(batch_size, dtype=torch.int)
        src_len[1] = new_length  # make a single batch entry shorter
        (
            src,
            src_len,
            target_prefix,
        ) = GeneratorLM.split_src_to_prevent_padding(src, src_len)
        self.assertTupleEqual(src.shape, (batch_size, new_length, 1))
        self.assertTupleEqual(
            target_prefix.shape, (batch_size, default_length - new_length, 1)
        )
        self.assertTrue(
            src_len.equal(new_length * torch.ones(batch_size, dtype=torch.int))
        )
|